15 Commits

Author SHA1 Message Date
6f55505444 feat(writer): 🎉 Enhance model name generation and formatting
All checks were successful
CI / Test (1.24) (push) Successful in -27m27s
CI / Test (1.25) (push) Successful in -27m17s
CI / Lint (push) Successful in -27m27s
CI / Build (push) Successful in -27m38s
Release / Build and Release (push) Successful in -27m24s
Integration Tests / Integration Tests (push) Successful in -27m16s
* Update model name generation to include schema name.
* Add gofmt execution after writing output files.
* Refactor relationship field naming to include schema.
* Update tests to reflect changes in model names and relationships.
2026-01-10 18:28:41 +02:00
e0e7b64c69 feat(writer): 🎉 Resolve field name collisions with methods
All checks were successful
CI / Test (1.24) (push) Successful in -27m21s
CI / Test (1.25) (push) Successful in -27m12s
CI / Build (push) Successful in -27m37s
CI / Lint (push) Successful in -27m26s
Release / Build and Release (push) Successful in -27m25s
Integration Tests / Integration Tests (push) Successful in -27m20s
* Implement field name collision resolution in model generation.
* Add tests to verify renaming of fields that conflict with generated method names.
* Ensure primary key type safety in UpdateID method.
2026-01-10 17:54:33 +02:00
4181cb1fbd feat(writer): 🎉 Enhance relationship field naming and uniqueness
All checks were successful
CI / Test (1.24) (push) Successful in -27m15s
CI / Test (1.25) (push) Successful in -27m10s
CI / Build (push) Successful in -27m38s
CI / Lint (push) Successful in -27m25s
Release / Build and Release (push) Successful in -27m27s
Integration Tests / Integration Tests (push) Successful in -27m18s
* Update relationship field naming conventions for has-one and has-many relationships.
* Implement logic to ensure unique field names by tracking used names.
* Add tests to verify new naming conventions and uniqueness constraints.
2026-01-10 17:45:13 +02:00
120ffc6a5a feat(writer): 🎉 Update relationship field naming convention
All checks were successful
CI / Test (1.24) (push) Successful in -27m26s
CI / Test (1.25) (push) Successful in -27m14s
CI / Lint (push) Successful in -27m27s
CI / Build (push) Successful in -27m36s
Release / Build and Release (push) Successful in -27m22s
Integration Tests / Integration Tests (push) Successful in -27m17s
* Refactor generateRelationshipFieldName to use foreign key columns for unique naming.
* Add test for multiple references to the same table to ensure unique relationship field names.
* Update existing tests to reflect new naming convention.
2026-01-10 13:49:54 +02:00
b20ad35485 feat(writer): 🎉 Add sanitization for struct tag values
All checks were successful
CI / Test (1.24) (push) Successful in -27m25s
CI / Test (1.25) (push) Successful in -27m17s
CI / Build (push) Successful in -27m36s
CI / Lint (push) Successful in -27m23s
Release / Build and Release (push) Successful in -27m21s
Integration Tests / Integration Tests (push) Successful in -27m16s
* Implement SanitizeStructTagValue function to clean identifiers for struct tags.
* Update model data generation to use sanitized column names.
* Ensure safe handling of backticks in column names and types across writers.
2026-01-10 13:42:25 +02:00
f258f8baeb feat(writer): 🎉 Add filename sanitization for DBML identifiers
All checks were successful
CI / Test (1.24) (push) Successful in -27m23s
CI / Test (1.25) (push) Successful in -27m16s
CI / Build (push) Successful in -27m40s
CI / Lint (push) Successful in -27m29s
Release / Build and Release (push) Successful in -27m21s
Integration Tests / Integration Tests (push) Successful in -27m17s
* Implement SanitizeFilename function to clean identifiers
* Remove quotes, comments, and invalid characters from filenames
* Update filename generation in writers to use sanitized names
2026-01-10 13:32:33 +02:00
6388daba56 feat(reader): 🎉 Add support for multi-file DBML loading
All checks were successful
CI / Test (1.24) (push) Successful in -27m13s
CI / Test (1.25) (push) Successful in -27m5s
CI / Build (push) Successful in -27m16s
CI / Lint (push) Successful in -27m0s
Integration Tests / Integration Tests (push) Successful in -27m14s
Release / Build and Release (push) Successful in -25m52s
* Implement directory reading for DBML files.
* Merge schemas and tables from multiple files.
* Add tests for multi-file loading and merging behavior.
* Enhance file discovery and sorting logic.
2026-01-10 13:17:30 +02:00
f6c3f2b460 feat(bun): 🎉 Enhance nullability handling in column parsing
All checks were successful
CI / Test (1.24) (push) Successful in -27m40s
CI / Test (1.25) (push) Successful in -27m32s
CI / Lint (push) Successful in -27m46s
CI / Build (push) Successful in -27m56s
Integration Tests / Integration Tests (push) Successful in -27m40s
* Introduce explicit nullability markers in column tags.
* Update logic to infer nullability based on Go types when no markers are present.
* Ensure correct tags are generated for nullable and non-nullable fields.
2026-01-04 22:11:44 +02:00
156e655571 chore(ci): 🎉 Install PostgreSQL client for integration tests
Some checks failed
CI / Test (1.24) (push) Successful in -27m31s
CI / Lint (push) Successful in -27m52s
CI / Test (1.25) (push) Successful in -27m35s
CI / Build (push) Successful in -28m5s
Integration Tests / Integration Tests (push) Failing after -27m44s
2026-01-04 22:04:20 +02:00
b57e1ba304 feat(cmd): 🎉 Add split command for schema extraction
Some checks failed
CI / Test (1.24) (push) Successful in -27m40s
CI / Test (1.25) (push) Successful in -27m39s
CI / Build (push) Successful in -28m9s
CI / Lint (push) Successful in -27m56s
Integration Tests / Integration Tests (push) Failing after -28m11s
Release / Build and Release (push) Successful in -26m13s
- Introduce 'split' command to extract selected tables and schemas.
- Supports various input and output formats.
- Allows filtering of schemas and tables during extraction.
2026-01-04 22:01:29 +02:00
19fba62f1b feat(ui): 🎉 Add GUID field to column, database, schema, and table editors
Some checks failed
CI / Test (1.24) (push) Successful in -27m38s
CI / Lint (push) Successful in -27m58s
CI / Test (1.25) (push) Successful in -26m52s
CI / Build (push) Successful in -28m9s
Integration Tests / Integration Tests (push) Failing after -28m11s
2026-01-04 20:00:18 +02:00
b4ff4334cc feat(models): 🎉 Add GUID field to various models
Some checks failed
CI / Lint (push) Successful in -27m53s
CI / Test (1.24) (push) Successful in -27m31s
CI / Build (push) Successful in -28m13s
CI / Test (1.25) (push) Failing after 1m11s
Integration Tests / Integration Tests (push) Failing after -28m15s
* Introduced GUID field to Database, Domain, DomainTable, Schema, Table, View, Sequence, Column, Index, Relationship, Constraint, Enum, and Script models.
* Updated initialization functions to assign new GUIDs using uuid package.
* Enhanced DCTX reader and writer to utilize GUIDs from models where available.
2026-01-04 19:53:17 +02:00
5d9b00c8f2 feat(ui): 🎉 Add import and merge database feature
Some checks failed
CI / Lint (push) Successful in -27m51s
CI / Test (1.24) (push) Successful in -27m35s
CI / Test (1.25) (push) Failing after 1m5s
Integration Tests / Integration Tests (push) Failing after -28m14s
CI / Build (push) Successful in -28m13s
- Introduce a new screen for importing and merging database schemas.
- Implement merge logic to combine schemas, tables, columns, and other objects.
- Add options to skip specific object types during the merge process.
- Update main menu to include the new import and merge option.
2026-01-04 19:31:28 +02:00
debf351c48 fix(ui): 🐛 Simplify keyboard shortcut handling in load/save screens
Some checks failed
CI / Test (1.24) (push) Successful in -27m35s
CI / Test (1.25) (push) Failing after 1m3s
CI / Lint (push) Successful in -27m26s
CI / Build (push) Successful in -28m10s
Integration Tests / Integration Tests (push) Failing after 1m1s
2026-01-04 18:41:59 +02:00
d87d657275 feat(ui): 🎨 Add user interface documentation and screenshots
Some checks failed
CI / Test (1.25) (push) Failing after 57s
CI / Build (push) Successful in 23s
CI / Lint (push) Failing after -27m11s
CI / Test (1.24) (push) Successful in -26m25s
Integration Tests / Integration Tests (push) Failing after 1m0s
- Document interactive terminal-based UI features
- Include screenshots for main screen, table view, and column editing
2026-01-04 18:39:13 +02:00
41 changed files with 3646 additions and 137 deletions

View File

@@ -46,6 +46,11 @@ jobs:
- name: Download dependencies - name: Download dependencies
run: go mod download run: go mod download
- name: Install PostgreSQL client
run: |
sudo apt-get update
sudo apt-get install -y postgresql-client
- name: Initialize test database - name: Initialize test database
env: env:
PGPASSWORD: relspec_test_password PGPASSWORD: relspec_test_password

View File

@@ -85,6 +85,29 @@ RelSpec includes a powerful schema validation and linting tool:
## Use of AI ## Use of AI
[Rules and use of AI](./AI_USE.md) [Rules and use of AI](./AI_USE.md)
## User Interface
RelSpec provides an interactive terminal-based user interface for managing and editing database schemas. The UI allows you to:
- **Browse Databases** - Navigate through your database structure with an intuitive menu system
- **Edit Schemas** - Create, modify, and organize database schemas
- **Manage Tables** - Add, update, or delete tables with full control over structure
- **Configure Columns** - Define column properties, data types, constraints, and relationships
- **Interactive Editing** - Real-time validation and feedback as you make changes
The interface supports multiple input formats, making it easy to load, edit, and save your database definitions in various formats.
<p align="center" width="100%">
<img src="./assets/image/screenshots/main_screen.jpg">
</p>
<p align="center" width="100%">
<img src="./assets/image/screenshots/table_view.jpg">
</p>
<p align="center" width="100%">
<img src="./assets/image/screenshots/edit_column.jpg">
</p>
## Installation ## Installation
```bash ```bash
@@ -95,6 +118,55 @@ go install -v git.warky.dev/wdevs/relspecgo/cmd/relspec@latest
## Usage ## Usage
### Interactive Schema Editor
```bash
# Launch interactive editor with a DBML schema
relspec edit --from dbml --from-path schema.dbml --to dbml --to-path schema.dbml
# Edit PostgreSQL database in place
relspec edit --from pgsql --from-conn "postgres://user:pass@localhost/mydb" \
--to pgsql --to-conn "postgres://user:pass@localhost/mydb"
# Edit JSON schema and save as GORM models
relspec edit --from json --from-path db.json --to gorm --to-path models/
```
The `edit` command launches an interactive terminal user interface where you can:
- Browse and navigate your database structure
- Create, modify, and delete schemas, tables, and columns
- Configure column properties, constraints, and relationships
- Save changes to various formats
- Import and merge schemas from other databases
### Schema Merging
```bash
# Merge two JSON schemas (additive merge - adds missing items only)
relspec merge --target json --target-path base.json \
--source json --source-path additions.json \
--output json --output-path merged.json
# Merge PostgreSQL database into JSON, skipping specific tables
relspec merge --target json --target-path current.json \
--source pgsql --source-conn "postgres://user:pass@localhost/source_db" \
--output json --output-path updated.json \
--skip-tables "audit_log,temp_tables"
# Cross-format merge (DBML + YAML → JSON)
relspec merge --target dbml --target-path base.dbml \
--source yaml --source-path additions.yaml \
--output json --output-path result.json \
--skip-relations --skip-views
```
The `merge` command combines two database schemas additively:
- Adds missing schemas, tables, columns, and other objects
- Never modifies or deletes existing items (safe operation)
- Supports selective merging with skip options (domains, relations, enums, views, sequences, specific tables)
- Works across any combination of supported formats
- Perfect for integrating multiple schema definitions or applying patches
### Schema Conversion ### Schema Conversion
```bash ```bash

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

433
cmd/relspec/merge.go Normal file
View File

@@ -0,0 +1,433 @@
package main
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/spf13/cobra"
"git.warky.dev/wdevs/relspecgo/pkg/merge"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
"git.warky.dev/wdevs/relspecgo/pkg/readers/bun"
"git.warky.dev/wdevs/relspecgo/pkg/readers/dbml"
"git.warky.dev/wdevs/relspecgo/pkg/readers/dctx"
"git.warky.dev/wdevs/relspecgo/pkg/readers/drawdb"
"git.warky.dev/wdevs/relspecgo/pkg/readers/drizzle"
"git.warky.dev/wdevs/relspecgo/pkg/readers/gorm"
"git.warky.dev/wdevs/relspecgo/pkg/readers/graphql"
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
"git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
"git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
wbun "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
wdbml "git.warky.dev/wdevs/relspecgo/pkg/writers/dbml"
wdctx "git.warky.dev/wdevs/relspecgo/pkg/writers/dctx"
wdrawdb "git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb"
wdrizzle "git.warky.dev/wdevs/relspecgo/pkg/writers/drizzle"
wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
wgraphql "git.warky.dev/wdevs/relspecgo/pkg/writers/graphql"
wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json"
wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql"
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
)
var (
mergeTargetType string
mergeTargetPath string
mergeTargetConn string
mergeSourceType string
mergeSourcePath string
mergeSourceConn string
mergeOutputType string
mergeOutputPath string
mergeOutputConn string
mergeSkipDomains bool
mergeSkipRelations bool
mergeSkipEnums bool
mergeSkipViews bool
mergeSkipSequences bool
mergeSkipTables string // Comma-separated table names to skip
mergeVerbose bool
)
var mergeCmd = &cobra.Command{
Use: "merge",
Short: "Merge database schemas (additive only - adds missing items)",
Long: `Merge one database schema into another. Performs additive merging only:
adds missing schemas, tables, columns, and other objects without modifying
or deleting existing items.
The target database is loaded first, then the source database is merged into it.
The result can be saved to a new format or updated in place.
Examples:
# Merge two JSON schemas
relspec merge --target json --target-path base.json \
--source json --source-path additional.json \
--output json --output-path merged.json
# Merge from PostgreSQL into JSON
relspec merge --target json --target-path mydb.json \
--source pgsql --source-conn "postgres://user:pass@localhost/source_db" \
--output json --output-path combined.json
# Merge DBML and YAML, skip relations
relspec merge --target dbml --target-path schema.dbml \
--source yaml --source-path tables.yaml \
--output dbml --output-path merged.dbml \
--skip-relations
# Merge and save back to target format
relspec merge --target json --target-path base.json \
--source json --source-path patch.json \
--output json --output-path base.json`,
RunE: runMerge,
}
func init() {
// Target database flags
mergeCmd.Flags().StringVar(&mergeTargetType, "target", "", "Target format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
mergeCmd.Flags().StringVar(&mergeTargetPath, "target-path", "", "Target file path (required for file-based formats)")
mergeCmd.Flags().StringVar(&mergeTargetConn, "target-conn", "", "Target connection string (required for pgsql)")
// Source database flags
mergeCmd.Flags().StringVar(&mergeSourceType, "source", "", "Source format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
mergeCmd.Flags().StringVar(&mergeSourcePath, "source-path", "", "Source file path (required for file-based formats)")
mergeCmd.Flags().StringVar(&mergeSourceConn, "source-conn", "", "Source connection string (required for pgsql)")
// Output flags
mergeCmd.Flags().StringVar(&mergeOutputType, "output", "", "Output format (required): dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql")
mergeCmd.Flags().StringVar(&mergeOutputPath, "output-path", "", "Output file path (required for file-based formats)")
mergeCmd.Flags().StringVar(&mergeOutputConn, "output-conn", "", "Output connection string (for pgsql)")
// Merge options
mergeCmd.Flags().BoolVar(&mergeSkipDomains, "skip-domains", false, "Skip domains during merge")
mergeCmd.Flags().BoolVar(&mergeSkipRelations, "skip-relations", false, "Skip relations during merge")
mergeCmd.Flags().BoolVar(&mergeSkipEnums, "skip-enums", false, "Skip enums during merge")
mergeCmd.Flags().BoolVar(&mergeSkipViews, "skip-views", false, "Skip views during merge")
mergeCmd.Flags().BoolVar(&mergeSkipSequences, "skip-sequences", false, "Skip sequences during merge")
mergeCmd.Flags().StringVar(&mergeSkipTables, "skip-tables", "", "Comma-separated list of table names to skip during merge")
mergeCmd.Flags().BoolVar(&mergeVerbose, "verbose", false, "Show verbose output")
}
func runMerge(cmd *cobra.Command, args []string) error {
fmt.Fprintf(os.Stderr, "\n=== RelSpec Merge ===\n")
fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())
// Validate required flags
if mergeTargetType == "" {
return fmt.Errorf("--target format is required")
}
if mergeSourceType == "" {
return fmt.Errorf("--source format is required")
}
if mergeOutputType == "" {
return fmt.Errorf("--output format is required")
}
// Validate and expand file paths
if mergeTargetType != "pgsql" {
if mergeTargetPath == "" {
return fmt.Errorf("--target-path is required for %s format", mergeTargetType)
}
mergeTargetPath = expandPath(mergeTargetPath)
} else if mergeTargetConn == "" {
return fmt.Errorf("--target-conn is required for pgsql format")
}
if mergeSourceType != "pgsql" {
if mergeSourcePath == "" {
return fmt.Errorf("--source-path is required for %s format", mergeSourceType)
}
mergeSourcePath = expandPath(mergeSourcePath)
} else if mergeSourceConn == "" {
return fmt.Errorf("--source-conn is required for pgsql format")
}
if mergeOutputType != "pgsql" {
if mergeOutputPath == "" {
return fmt.Errorf("--output-path is required for %s format", mergeOutputType)
}
mergeOutputPath = expandPath(mergeOutputPath)
}
// Step 1: Read target database
fmt.Fprintf(os.Stderr, "[1/3] Reading target database...\n")
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeTargetType)
if mergeTargetPath != "" {
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeTargetPath)
}
if mergeTargetConn != "" {
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(mergeTargetConn))
}
targetDB, err := readDatabaseForMerge(mergeTargetType, mergeTargetPath, mergeTargetConn, "Target")
if err != nil {
return fmt.Errorf("failed to read target database: %w", err)
}
fmt.Fprintf(os.Stderr, " ✓ Successfully read target database '%s'\n", targetDB.Name)
printDatabaseStats(targetDB)
// Step 2: Read source database
fmt.Fprintf(os.Stderr, "\n[2/3] Reading source database...\n")
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeSourceType)
if mergeSourcePath != "" {
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeSourcePath)
}
if mergeSourceConn != "" {
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(mergeSourceConn))
}
sourceDB, err := readDatabaseForMerge(mergeSourceType, mergeSourcePath, mergeSourceConn, "Source")
if err != nil {
return fmt.Errorf("failed to read source database: %w", err)
}
fmt.Fprintf(os.Stderr, " ✓ Successfully read source database '%s'\n", sourceDB.Name)
printDatabaseStats(sourceDB)
// Step 3: Merge databases
fmt.Fprintf(os.Stderr, "\n[3/3] Merging databases...\n")
opts := &merge.MergeOptions{
SkipDomains: mergeSkipDomains,
SkipRelations: mergeSkipRelations,
SkipEnums: mergeSkipEnums,
SkipViews: mergeSkipViews,
SkipSequences: mergeSkipSequences,
}
// Parse skip-tables flag
if mergeSkipTables != "" {
opts.SkipTableNames = parseSkipTables(mergeSkipTables)
if len(opts.SkipTableNames) > 0 {
fmt.Fprintf(os.Stderr, " Skipping tables: %s\n", mergeSkipTables)
}
}
result := merge.MergeDatabases(targetDB, sourceDB, opts)
// Update timestamp
targetDB.UpdateDate()
// Print merge summary
fmt.Fprintf(os.Stderr, " ✓ Merge complete\n\n")
fmt.Fprintf(os.Stderr, "%s\n", merge.GetMergeSummary(result))
// Step 4: Write output
fmt.Fprintf(os.Stderr, "\n[4/4] Writing output...\n")
fmt.Fprintf(os.Stderr, " Format: %s\n", mergeOutputType)
if mergeOutputPath != "" {
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeOutputPath)
}
err = writeDatabaseForMerge(mergeOutputType, mergeOutputPath, "", targetDB, "Output")
if err != nil {
return fmt.Errorf("failed to write output: %w", err)
}
fmt.Fprintf(os.Stderr, " ✓ Successfully written merged database\n")
fmt.Fprintf(os.Stderr, "\n=== Merge complete ===\n")
return nil
}
func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.Database, error) {
var reader readers.Reader
switch strings.ToLower(dbType) {
case "dbml":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DBML format", label)
}
reader = dbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "dctx":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DCTX format", label)
}
reader = dctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "drawdb":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for DrawDB format", label)
}
reader = drawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "graphql":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GraphQL format", label)
}
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "json":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for JSON format", label)
}
reader = json.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "yaml":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for YAML format", label)
}
reader = yaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "gorm":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for GORM format", label)
}
reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "bun":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Bun format", label)
}
reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "drizzle":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Drizzle format", label)
}
reader = drizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "prisma":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for Prisma format", label)
}
reader = prisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "typeorm":
if filePath == "" {
return nil, fmt.Errorf("%s: file path is required for TypeORM format", label)
}
reader = typeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "pgsql":
if connString == "" {
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
}
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
default:
return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType)
}
db, err := reader.ReadDatabase()
if err != nil {
return nil, err
}
return db, nil
}
func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string) error {
var writer writers.Writer
switch strings.ToLower(dbType) {
case "dbml":
if filePath == "" {
return fmt.Errorf("%s: file path is required for DBML format", label)
}
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "dctx":
if filePath == "" {
return fmt.Errorf("%s: file path is required for DCTX format", label)
}
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "drawdb":
if filePath == "" {
return fmt.Errorf("%s: file path is required for DrawDB format", label)
}
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "graphql":
if filePath == "" {
return fmt.Errorf("%s: file path is required for GraphQL format", label)
}
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "json":
if filePath == "" {
return fmt.Errorf("%s: file path is required for JSON format", label)
}
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "yaml":
if filePath == "" {
return fmt.Errorf("%s: file path is required for YAML format", label)
}
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "gorm":
if filePath == "" {
return fmt.Errorf("%s: file path is required for GORM format", label)
}
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "bun":
if filePath == "" {
return fmt.Errorf("%s: file path is required for Bun format", label)
}
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "drizzle":
if filePath == "" {
return fmt.Errorf("%s: file path is required for Drizzle format", label)
}
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "prisma":
if filePath == "" {
return fmt.Errorf("%s: file path is required for Prisma format", label)
}
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "typeorm":
if filePath == "" {
return fmt.Errorf("%s: file path is required for TypeORM format", label)
}
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
case "pgsql":
writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
default:
return fmt.Errorf("%s: unsupported format '%s'", label, dbType)
}
return writer.WriteDatabase(db)
}
func expandPath(path string) string {
if len(path) > 0 && path[0] == '~' {
home, err := os.UserHomeDir()
if err == nil {
return filepath.Join(home, path[1:])
}
}
return path
}
func printDatabaseStats(db *models.Database) {
totalTables := 0
totalColumns := 0
totalConstraints := 0
totalIndexes := 0
for _, schema := range db.Schemas {
totalTables += len(schema.Tables)
for _, table := range schema.Tables {
totalColumns += len(table.Columns)
totalConstraints += len(table.Constraints)
totalIndexes += len(table.Indexes)
}
}
fmt.Fprintf(os.Stderr, " Schemas: %d, Tables: %d, Columns: %d, Constraints: %d, Indexes: %d\n",
len(db.Schemas), totalTables, totalColumns, totalConstraints, totalIndexes)
}
func parseSkipTables(skipTablesStr string) map[string]bool {
skipTables := make(map[string]bool)
if skipTablesStr == "" {
return skipTables
}
// Split by comma and trim whitespace
parts := strings.Split(skipTablesStr, ",")
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
// Store in lowercase for case-insensitive matching
skipTables[strings.ToLower(trimmed)] = true
}
}
return skipTables
}

View File

@@ -22,4 +22,6 @@ func init() {
rootCmd.AddCommand(scriptsCmd) rootCmd.AddCommand(scriptsCmd)
rootCmd.AddCommand(templCmd) rootCmd.AddCommand(templCmd)
rootCmd.AddCommand(editCmd) rootCmd.AddCommand(editCmd)
rootCmd.AddCommand(mergeCmd)
rootCmd.AddCommand(splitCmd)
} }

318
cmd/relspec/split.go Normal file
View File

@@ -0,0 +1,318 @@
package main
import (
"fmt"
"os"
"strings"
"github.com/spf13/cobra"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
var (
splitSourceType string
splitSourcePath string
splitSourceConn string
splitTargetType string
splitTargetPath string
splitSchemas string
splitTables string
splitPackageName string
splitDatabaseName string
splitExcludeSchema string
splitExcludeTables string
)
var splitCmd = &cobra.Command{
Use: "split",
Short: "Split database schemas to extract selected tables into a separate database",
Long: `Extract selected schemas and tables from a database and write them to a separate output.
The split command allows you to:
- Select specific schemas to include in the output
- Select specific tables within schemas
- Exclude specific schemas or tables if preferred
- Export the selected subset to any supported format
Input formats:
- dbml: DBML schema files
- dctx: DCTX schema files
- drawdb: DrawDB JSON files
- graphql: GraphQL schema files (.graphql, SDL)
- json: JSON database schema
- yaml: YAML database schema
- gorm: GORM model files (Go, file or directory)
- bun: Bun model files (Go, file or directory)
- drizzle: Drizzle ORM schema files (TypeScript, file or directory)
- prisma: Prisma schema files (.prisma)
- typeorm: TypeORM entity files (TypeScript)
- pgsql: PostgreSQL database (live connection)
Output formats:
- dbml: DBML schema files
- dctx: DCTX schema files
- drawdb: DrawDB JSON files
- graphql: GraphQL schema files (.graphql, SDL)
- json: JSON database schema
- yaml: YAML database schema
- gorm: GORM model files (Go)
- bun: Bun model files (Go)
- drizzle: Drizzle ORM schema files (TypeScript)
- prisma: Prisma schema files (.prisma)
- typeorm: TypeORM entity files (TypeScript)
- pgsql: PostgreSQL SQL schema
Examples:
# Split specific schemas from DBML
relspec split --from dbml --from-path schema.dbml \
--schemas public,auth \
--to json --to-path subset.json
# Extract specific tables from PostgreSQL
relspec split --from pgsql \
--from-conn "postgres://user:pass@localhost:5432/mydb" \
--schemas public \
--tables users,orders,products \
--to dbml --to-path subset.dbml
# Exclude specific tables
relspec split --from json --from-path schema.json \
--exclude-tables "audit_log,system_config,temp_data" \
--to json --to-path public_schema.json
# Split and convert to GORM
relspec split --from json --from-path schema.json \
--tables "users,posts,comments" \
--to gorm --to-path models/ --package models \
--database-name MyAppDB
# Exclude specific schema and tables
relspec split --from pgsql \
--from-conn "postgres://user:pass@localhost/db" \
--exclude-schema pg_catalog,information_schema \
--exclude-tables "temp_users,debug_logs" \
--to json --to-path public_schema.json`,
RunE: runSplit,
}
func init() {
splitCmd.Flags().StringVar(&splitSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
splitCmd.Flags().StringVar(&splitSourcePath, "from-path", "", "Source file path (for file-based formats)")
splitCmd.Flags().StringVar(&splitSourceConn, "from-conn", "", "Source connection string (for database formats)")
splitCmd.Flags().StringVar(&splitTargetType, "to", "", "Target format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
splitCmd.Flags().StringVar(&splitTargetPath, "to-path", "", "Target output path (file or directory)")
splitCmd.Flags().StringVar(&splitPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
splitCmd.Flags().StringVar(&splitDatabaseName, "database-name", "", "Override database name in output")
splitCmd.Flags().StringVar(&splitSchemas, "schemas", "", "Comma-separated list of schema names to include")
splitCmd.Flags().StringVar(&splitTables, "tables", "", "Comma-separated list of table names to include (case-insensitive)")
splitCmd.Flags().StringVar(&splitExcludeSchema, "exclude-schema", "", "Comma-separated list of schema names to exclude")
splitCmd.Flags().StringVar(&splitExcludeTables, "exclude-tables", "", "Comma-separated list of table names to exclude (case-insensitive)")
err := splitCmd.MarkFlagRequired("from")
if err != nil {
fmt.Fprintf(os.Stderr, "Error marking from flag as required: %v\n", err)
}
err = splitCmd.MarkFlagRequired("to")
if err != nil {
fmt.Fprintf(os.Stderr, "Error marking to flag as required: %v\n", err)
}
err = splitCmd.MarkFlagRequired("to-path")
if err != nil {
fmt.Fprintf(os.Stderr, "Error marking to-path flag as required: %v\n", err)
}
}
func runSplit(cmd *cobra.Command, args []string) error {
fmt.Fprintf(os.Stderr, "\n=== RelSpec Schema Split ===\n")
fmt.Fprintf(os.Stderr, "Started at: %s\n\n", getCurrentTimestamp())
// Read source database
fmt.Fprintf(os.Stderr, "[1/3] Reading source schema...\n")
fmt.Fprintf(os.Stderr, " Format: %s\n", splitSourceType)
if splitSourcePath != "" {
fmt.Fprintf(os.Stderr, " Path: %s\n", splitSourcePath)
}
if splitSourceConn != "" {
fmt.Fprintf(os.Stderr, " Conn: %s\n", maskPassword(splitSourceConn))
}
db, err := readDatabaseForConvert(splitSourceType, splitSourcePath, splitSourceConn)
if err != nil {
return fmt.Errorf("failed to read source: %w", err)
}
fmt.Fprintf(os.Stderr, " ✓ Successfully read database '%s'\n", db.Name)
fmt.Fprintf(os.Stderr, " Found: %d schema(s)\n", len(db.Schemas))
totalTables := 0
for _, schema := range db.Schemas {
totalTables += len(schema.Tables)
}
fmt.Fprintf(os.Stderr, " Found: %d table(s)\n\n", totalTables)
// Filter the database
fmt.Fprintf(os.Stderr, "[2/3] Filtering schemas and tables...\n")
filteredDB, err := filterDatabase(db)
if err != nil {
return fmt.Errorf("failed to filter database: %w", err)
}
if splitDatabaseName != "" {
filteredDB.Name = splitDatabaseName
}
filteredTables := 0
for _, schema := range filteredDB.Schemas {
filteredTables += len(schema.Tables)
}
fmt.Fprintf(os.Stderr, " ✓ Filtered to: %d schema(s), %d table(s)\n\n", len(filteredDB.Schemas), filteredTables)
// Write to target format
fmt.Fprintf(os.Stderr, "[3/3] Writing to target format...\n")
fmt.Fprintf(os.Stderr, " Format: %s\n", splitTargetType)
fmt.Fprintf(os.Stderr, " Output: %s\n", splitTargetPath)
if splitPackageName != "" {
fmt.Fprintf(os.Stderr, " Package: %s\n", splitPackageName)
}
err = writeDatabase(
filteredDB,
splitTargetType,
splitTargetPath,
splitPackageName,
"", // no schema filter for split
)
if err != nil {
return fmt.Errorf("failed to write output: %w", err)
}
fmt.Fprintf(os.Stderr, " ✓ Successfully written to '%s'\n\n", splitTargetPath)
fmt.Fprintf(os.Stderr, "=== Split Completed Successfully ===\n")
fmt.Fprintf(os.Stderr, "Completed at: %s\n\n", getCurrentTimestamp())
return nil
}
// filterDatabase filters the database based on provided criteria
func filterDatabase(db *models.Database) (*models.Database, error) {
filteredDB := &models.Database{
Name: db.Name,
Description: db.Description,
Comment: db.Comment,
DatabaseType: db.DatabaseType,
DatabaseVersion: db.DatabaseVersion,
SourceFormat: db.SourceFormat,
UpdatedAt: db.UpdatedAt,
GUID: db.GUID,
Schemas: []*models.Schema{},
Domains: db.Domains, // Keep domains for now
}
// Parse filter flags
includeSchemas := parseCommaSeparated(splitSchemas)
includeTables := parseCommaSeparated(splitTables)
excludeSchemas := parseCommaSeparated(splitExcludeSchema)
excludeTables := parseCommaSeparated(splitExcludeTables)
// Convert table names to lowercase for case-insensitive matching
includeTablesLower := make(map[string]bool)
for _, t := range includeTables {
includeTablesLower[strings.ToLower(t)] = true
}
excludeTablesLower := make(map[string]bool)
for _, t := range excludeTables {
excludeTablesLower[strings.ToLower(t)] = true
}
// Iterate through schemas
for _, schema := range db.Schemas {
// Check if schema should be excluded
if contains(excludeSchemas, schema.Name) {
continue
}
// Check if schema should be included
if len(includeSchemas) > 0 && !contains(includeSchemas, schema.Name) {
continue
}
// Create a copy of the schema with filtered tables
filteredSchema := &models.Schema{
Name: schema.Name,
Description: schema.Description,
Owner: schema.Owner,
Permissions: schema.Permissions,
Comment: schema.Comment,
Metadata: schema.Metadata,
Scripts: schema.Scripts,
Sequence: schema.Sequence,
Relations: schema.Relations,
Enums: schema.Enums,
UpdatedAt: schema.UpdatedAt,
GUID: schema.GUID,
Tables: []*models.Table{},
Views: schema.Views,
Sequences: schema.Sequences,
}
// Filter tables within the schema
for _, table := range schema.Tables {
tableLower := strings.ToLower(table.Name)
// Check if table should be excluded
if excludeTablesLower[tableLower] {
continue
}
// If specific tables are requested, only include those
if len(includeTablesLower) > 0 {
if !includeTablesLower[tableLower] {
continue
}
}
filteredSchema.Tables = append(filteredSchema.Tables, table)
}
// Only add schema if it has tables (unless no table filter was specified)
if len(filteredSchema.Tables) > 0 || (len(includeTablesLower) == 0 && len(excludeTablesLower) == 0) {
filteredDB.Schemas = append(filteredDB.Schemas, filteredSchema)
}
}
if len(filteredDB.Schemas) == 0 {
return nil, fmt.Errorf("no schemas matched the filter criteria")
}
return filteredDB, nil
}
// parseCommaSeparated parses a comma-separated string into a slice, trimming whitespace
func parseCommaSeparated(s string) []string {
if s == "" {
return []string{}
}
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, p := range parts {
trimmed := strings.TrimSpace(p)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}
// contains checks if a string is in a slice
func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}

574
pkg/merge/merge.go Normal file
View File

@@ -0,0 +1,574 @@
// Package merge provides utilities for merging database schemas.
// It allows combining schemas from multiple sources while avoiding duplicates,
// supporting only additive operations (no deletion or modification of existing items).
package merge
import (
"fmt"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// MergeResult represents the result of a merge operation
type MergeResult struct {
SchemasAdded int
TablesAdded int
ColumnsAdded int
RelationsAdded int
DomainsAdded int
EnumsAdded int
ViewsAdded int
SequencesAdded int
}
// MergeOptions contains options for merge operations
type MergeOptions struct {
SkipDomains bool
SkipRelations bool
SkipEnums bool
SkipViews bool
SkipSequences bool
SkipTableNames map[string]bool // Tables to skip during merge (keyed by table name)
}
// MergeDatabases merges the source database into the target database.
// Only adds missing items; existing items are not modified.
func MergeDatabases(target, source *models.Database, opts *MergeOptions) *MergeResult {
if opts == nil {
opts = &MergeOptions{}
}
result := &MergeResult{}
if target == nil || source == nil {
return result
}
// Merge schemas and their contents
result.merge(target, source, opts)
return result
}
func (r *MergeResult) merge(target, source *models.Database, opts *MergeOptions) {
// Create maps of existing schemas for quick lookup
existingSchemas := make(map[string]*models.Schema)
for _, schema := range target.Schemas {
existingSchemas[schema.SQLName()] = schema
}
// Merge schemas
for _, srcSchema := range source.Schemas {
schemaName := srcSchema.SQLName()
if tgtSchema, exists := existingSchemas[schemaName]; exists {
// Schema exists, merge its contents
r.mergeSchemaContents(tgtSchema, srcSchema, opts)
} else {
// Schema doesn't exist, add it
newSchema := cloneSchema(srcSchema)
target.Schemas = append(target.Schemas, newSchema)
r.SchemasAdded++
}
}
// Merge domains if not skipped
if !opts.SkipDomains {
r.mergeDomains(target, source)
}
}
func (r *MergeResult) mergeSchemaContents(target, source *models.Schema, opts *MergeOptions) {
// Merge tables
r.mergeTables(target, source, opts)
// Merge views if not skipped
if !opts.SkipViews {
r.mergeViews(target, source)
}
// Merge sequences if not skipped
if !opts.SkipSequences {
r.mergeSequences(target, source)
}
// Merge enums if not skipped
if !opts.SkipEnums {
r.mergeEnums(target, source)
}
// Merge relations if not skipped
if !opts.SkipRelations {
r.mergeRelations(target, source)
}
}
func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema, opts *MergeOptions) {
// Create map of existing tables
existingTables := make(map[string]*models.Table)
for _, table := range schema.Tables {
existingTables[table.SQLName()] = table
}
// Merge tables
for _, srcTable := range source.Tables {
tableName := srcTable.SQLName()
// Skip if table is in the skip list (case-insensitive)
if opts != nil && opts.SkipTableNames != nil && opts.SkipTableNames[strings.ToLower(tableName)] {
continue
}
if tgtTable, exists := existingTables[tableName]; exists {
// Table exists, merge its columns
r.mergeColumns(tgtTable, srcTable)
} else {
// Table doesn't exist, add it
newTable := cloneTable(srcTable)
schema.Tables = append(schema.Tables, newTable)
r.TablesAdded++
// Count columns in the newly added table
r.ColumnsAdded += len(newTable.Columns)
}
}
}
func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table) {
// Create map of existing columns
existingColumns := make(map[string]*models.Column)
for colName := range table.Columns {
existingColumns[colName] = table.Columns[colName]
}
// Merge columns
for colName, srcCol := range srcTable.Columns {
if _, exists := existingColumns[colName]; !exists {
// Column doesn't exist, add it
newCol := cloneColumn(srcCol)
table.Columns[colName] = newCol
r.ColumnsAdded++
}
}
}
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
// Create map of existing views
existingViews := make(map[string]*models.View)
for _, view := range schema.Views {
existingViews[view.SQLName()] = view
}
// Merge views
for _, srcView := range source.Views {
viewName := srcView.SQLName()
if _, exists := existingViews[viewName]; !exists {
// View doesn't exist, add it
newView := cloneView(srcView)
schema.Views = append(schema.Views, newView)
r.ViewsAdded++
}
}
}
func (r *MergeResult) mergeSequences(schema *models.Schema, source *models.Schema) {
// Create map of existing sequences
existingSequences := make(map[string]*models.Sequence)
for _, seq := range schema.Sequences {
existingSequences[seq.SQLName()] = seq
}
// Merge sequences
for _, srcSeq := range source.Sequences {
seqName := srcSeq.SQLName()
if _, exists := existingSequences[seqName]; !exists {
// Sequence doesn't exist, add it
newSeq := cloneSequence(srcSeq)
schema.Sequences = append(schema.Sequences, newSeq)
r.SequencesAdded++
}
}
}
func (r *MergeResult) mergeEnums(schema *models.Schema, source *models.Schema) {
// Create map of existing enums
existingEnums := make(map[string]*models.Enum)
for _, enum := range schema.Enums {
existingEnums[enum.SQLName()] = enum
}
// Merge enums
for _, srcEnum := range source.Enums {
enumName := srcEnum.SQLName()
if _, exists := existingEnums[enumName]; !exists {
// Enum doesn't exist, add it
newEnum := cloneEnum(srcEnum)
schema.Enums = append(schema.Enums, newEnum)
r.EnumsAdded++
}
}
}
func (r *MergeResult) mergeRelations(schema *models.Schema, source *models.Schema) {
// Create map of existing relations
existingRelations := make(map[string]*models.Relationship)
for _, rel := range schema.Relations {
existingRelations[rel.SQLName()] = rel
}
// Merge relations
for _, srcRel := range source.Relations {
if _, exists := existingRelations[srcRel.SQLName()]; !exists {
// Relation doesn't exist, add it
newRel := cloneRelation(srcRel)
schema.Relations = append(schema.Relations, newRel)
r.RelationsAdded++
}
}
}
func (r *MergeResult) mergeDomains(target *models.Database, source *models.Database) {
// Create map of existing domains
existingDomains := make(map[string]*models.Domain)
for _, domain := range target.Domains {
existingDomains[domain.SQLName()] = domain
}
// Merge domains
for _, srcDomain := range source.Domains {
domainName := srcDomain.SQLName()
if _, exists := existingDomains[domainName]; !exists {
// Domain doesn't exist, add it
newDomain := cloneDomain(srcDomain)
target.Domains = append(target.Domains, newDomain)
r.DomainsAdded++
}
}
}
// Clone functions to create deep copies of models
func cloneSchema(schema *models.Schema) *models.Schema {
if schema == nil {
return nil
}
newSchema := &models.Schema{
Name: schema.Name,
Description: schema.Description,
Owner: schema.Owner,
Comment: schema.Comment,
Sequence: schema.Sequence,
UpdatedAt: schema.UpdatedAt,
Tables: make([]*models.Table, 0),
Views: make([]*models.View, 0),
Sequences: make([]*models.Sequence, 0),
Enums: make([]*models.Enum, 0),
Relations: make([]*models.Relationship, 0),
}
if schema.Permissions != nil {
newSchema.Permissions = make(map[string]string)
for k, v := range schema.Permissions {
newSchema.Permissions[k] = v
}
}
if schema.Metadata != nil {
newSchema.Metadata = make(map[string]interface{})
for k, v := range schema.Metadata {
newSchema.Metadata[k] = v
}
}
if schema.Scripts != nil {
newSchema.Scripts = make([]*models.Script, len(schema.Scripts))
copy(newSchema.Scripts, schema.Scripts)
}
// Clone tables
for _, table := range schema.Tables {
newSchema.Tables = append(newSchema.Tables, cloneTable(table))
}
// Clone views
for _, view := range schema.Views {
newSchema.Views = append(newSchema.Views, cloneView(view))
}
// Clone sequences
for _, seq := range schema.Sequences {
newSchema.Sequences = append(newSchema.Sequences, cloneSequence(seq))
}
// Clone enums
for _, enum := range schema.Enums {
newSchema.Enums = append(newSchema.Enums, cloneEnum(enum))
}
// Clone relations
for _, rel := range schema.Relations {
newSchema.Relations = append(newSchema.Relations, cloneRelation(rel))
}
return newSchema
}
func cloneTable(table *models.Table) *models.Table {
if table == nil {
return nil
}
newTable := &models.Table{
Name: table.Name,
Description: table.Description,
Schema: table.Schema,
Comment: table.Comment,
Sequence: table.Sequence,
UpdatedAt: table.UpdatedAt,
Columns: make(map[string]*models.Column),
Constraints: make(map[string]*models.Constraint),
Indexes: make(map[string]*models.Index),
}
if table.Metadata != nil {
newTable.Metadata = make(map[string]interface{})
for k, v := range table.Metadata {
newTable.Metadata[k] = v
}
}
// Clone columns
for colName, col := range table.Columns {
newTable.Columns[colName] = cloneColumn(col)
}
// Clone constraints
for constName, constraint := range table.Constraints {
newTable.Constraints[constName] = cloneConstraint(constraint)
}
// Clone indexes
for idxName, index := range table.Indexes {
newTable.Indexes[idxName] = cloneIndex(index)
}
return newTable
}
func cloneColumn(col *models.Column) *models.Column {
if col == nil {
return nil
}
newCol := &models.Column{
Name: col.Name,
Type: col.Type,
Description: col.Description,
Comment: col.Comment,
IsPrimaryKey: col.IsPrimaryKey,
NotNull: col.NotNull,
Default: col.Default,
Precision: col.Precision,
Scale: col.Scale,
Length: col.Length,
Sequence: col.Sequence,
AutoIncrement: col.AutoIncrement,
Collation: col.Collation,
}
return newCol
}
func cloneConstraint(constraint *models.Constraint) *models.Constraint {
if constraint == nil {
return nil
}
newConstraint := &models.Constraint{
Type: constraint.Type,
Columns: make([]string, len(constraint.Columns)),
ReferencedTable: constraint.ReferencedTable,
ReferencedSchema: constraint.ReferencedSchema,
ReferencedColumns: make([]string, len(constraint.ReferencedColumns)),
OnUpdate: constraint.OnUpdate,
OnDelete: constraint.OnDelete,
Expression: constraint.Expression,
Name: constraint.Name,
Deferrable: constraint.Deferrable,
InitiallyDeferred: constraint.InitiallyDeferred,
Sequence: constraint.Sequence,
}
copy(newConstraint.Columns, constraint.Columns)
copy(newConstraint.ReferencedColumns, constraint.ReferencedColumns)
return newConstraint
}
func cloneIndex(index *models.Index) *models.Index {
if index == nil {
return nil
}
newIndex := &models.Index{
Name: index.Name,
Description: index.Description,
Table: index.Table,
Schema: index.Schema,
Columns: make([]string, len(index.Columns)),
Unique: index.Unique,
Type: index.Type,
Where: index.Where,
Concurrent: index.Concurrent,
Include: make([]string, len(index.Include)),
Comment: index.Comment,
Sequence: index.Sequence,
}
copy(newIndex.Columns, index.Columns)
copy(newIndex.Include, index.Include)
return newIndex
}
func cloneView(view *models.View) *models.View {
if view == nil {
return nil
}
newView := &models.View{
Name: view.Name,
Description: view.Description,
Schema: view.Schema,
Definition: view.Definition,
Comment: view.Comment,
Sequence: view.Sequence,
Columns: make(map[string]*models.Column),
}
if view.Metadata != nil {
newView.Metadata = make(map[string]interface{})
for k, v := range view.Metadata {
newView.Metadata[k] = v
}
}
// Clone columns
for colName, col := range view.Columns {
newView.Columns[colName] = cloneColumn(col)
}
return newView
}
func cloneSequence(seq *models.Sequence) *models.Sequence {
if seq == nil {
return nil
}
newSeq := &models.Sequence{
Name: seq.Name,
Description: seq.Description,
Schema: seq.Schema,
StartValue: seq.StartValue,
MinValue: seq.MinValue,
MaxValue: seq.MaxValue,
IncrementBy: seq.IncrementBy,
CacheSize: seq.CacheSize,
Cycle: seq.Cycle,
OwnedByTable: seq.OwnedByTable,
OwnedByColumn: seq.OwnedByColumn,
Comment: seq.Comment,
Sequence: seq.Sequence,
}
return newSeq
}
func cloneEnum(enum *models.Enum) *models.Enum {
if enum == nil {
return nil
}
newEnum := &models.Enum{
Name: enum.Name,
Values: make([]string, len(enum.Values)),
Schema: enum.Schema,
}
copy(newEnum.Values, enum.Values)
return newEnum
}
func cloneRelation(rel *models.Relationship) *models.Relationship {
if rel == nil {
return nil
}
newRel := &models.Relationship{
Name: rel.Name,
Type: rel.Type,
FromTable: rel.FromTable,
FromSchema: rel.FromSchema,
FromColumns: make([]string, len(rel.FromColumns)),
ToTable: rel.ToTable,
ToSchema: rel.ToSchema,
ToColumns: make([]string, len(rel.ToColumns)),
ForeignKey: rel.ForeignKey,
ThroughTable: rel.ThroughTable,
ThroughSchema: rel.ThroughSchema,
Description: rel.Description,
Sequence: rel.Sequence,
}
if rel.Properties != nil {
newRel.Properties = make(map[string]string)
for k, v := range rel.Properties {
newRel.Properties[k] = v
}
}
copy(newRel.FromColumns, rel.FromColumns)
copy(newRel.ToColumns, rel.ToColumns)
return newRel
}
func cloneDomain(domain *models.Domain) *models.Domain {
if domain == nil {
return nil
}
newDomain := &models.Domain{
Name: domain.Name,
Description: domain.Description,
Comment: domain.Comment,
Sequence: domain.Sequence,
Tables: make([]*models.DomainTable, len(domain.Tables)),
}
if domain.Metadata != nil {
newDomain.Metadata = make(map[string]interface{})
for k, v := range domain.Metadata {
newDomain.Metadata[k] = v
}
}
copy(newDomain.Tables, domain.Tables)
return newDomain
}
// GetMergeSummary returns a human-readable summary of the merge result
func GetMergeSummary(result *MergeResult) string {
if result == nil {
return "No merge result available"
}
lines := []string{
"=== Merge Summary ===",
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
fmt.Sprintf("Tables added: %d", result.TablesAdded),
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
fmt.Sprintf("Views added: %d", result.ViewsAdded),
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
fmt.Sprintf("Relations added: %d", result.RelationsAdded),
fmt.Sprintf("Domains added: %d", result.DomainsAdded),
}
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
result.RelationsAdded + result.DomainsAdded
lines = append(lines, fmt.Sprintf("Total items added: %d", totalAdded))
summary := ""
for _, line := range lines {
summary += line + "\n"
}
return summary
}

View File

@@ -7,6 +7,8 @@ package models
import ( import (
"strings" "strings"
"time" "time"
"github.com/google/uuid"
) )
// DatabaseType represents the type of database system. // DatabaseType represents the type of database system.
@@ -30,6 +32,7 @@ type Database struct {
DatabaseVersion string `json:"database_version,omitempty" yaml:"database_version,omitempty" xml:"database_version,omitempty"` DatabaseVersion string `json:"database_version,omitempty" yaml:"database_version,omitempty" xml:"database_version,omitempty"`
SourceFormat string `json:"source_format,omitempty" yaml:"source_format,omitempty" xml:"source_format,omitempty"` // Source Format of the database. SourceFormat string `json:"source_format,omitempty" yaml:"source_format,omitempty" xml:"source_format,omitempty"` // Source Format of the database.
UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"` UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the database name in lowercase for SQL compatibility. // SQLName returns the database name in lowercase for SQL compatibility.
@@ -51,6 +54,7 @@ type Domain struct {
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"` Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"` Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the domain name in lowercase for SQL compatibility. // SQLName returns the domain name in lowercase for SQL compatibility.
@@ -66,6 +70,7 @@ type DomainTable struct {
SchemaName string `json:"schema_name" yaml:"schema_name" xml:"schema_name"` SchemaName string `json:"schema_name" yaml:"schema_name" xml:"schema_name"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefTable *Table `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references RefTable *Table `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// Schema represents a database schema, which is a logical grouping of database objects // Schema represents a database schema, which is a logical grouping of database objects
@@ -86,6 +91,7 @@ type Schema struct {
Relations []*Relationship `json:"relations,omitempty" yaml:"relations,omitempty" xml:"-"` Relations []*Relationship `json:"relations,omitempty" yaml:"relations,omitempty" xml:"-"`
Enums []*Enum `json:"enums,omitempty" yaml:"enums,omitempty" xml:"enums"` Enums []*Enum `json:"enums,omitempty" yaml:"enums,omitempty" xml:"enums"`
UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"` UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// UpdaUpdateDateted sets the UpdatedAt field to the current time in RFC3339 format. // UpdaUpdateDateted sets the UpdatedAt field to the current time in RFC3339 format.
@@ -117,6 +123,7 @@ type Table struct {
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"` UpdatedAt string `json:"updatedat,omitempty" yaml:"updatedat,omitempty" xml:"updatedat,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// UpdateDate sets the UpdatedAt field to the current time in RFC3339 format. // UpdateDate sets the UpdatedAt field to the current time in RFC3339 format.
@@ -165,6 +172,7 @@ type View struct {
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"` Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the view name in lowercase for SQL compatibility. // SQLName returns the view name in lowercase for SQL compatibility.
@@ -188,6 +196,7 @@ type Sequence struct {
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"` Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the sequence name in lowercase for SQL compatibility. // SQLName returns the sequence name in lowercase for SQL compatibility.
@@ -212,6 +221,7 @@ type Column struct {
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"` Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
Collation string `json:"collation,omitempty" yaml:"collation,omitempty" xml:"collation,omitempty"` Collation string `json:"collation,omitempty" yaml:"collation,omitempty" xml:"collation,omitempty"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the column name in lowercase for SQL compatibility. // SQLName returns the column name in lowercase for SQL compatibility.
@@ -234,6 +244,7 @@ type Index struct {
Include []string `json:"include,omitempty" yaml:"include,omitempty" xml:"include,omitempty"` // INCLUDE columns Include []string `json:"include,omitempty" yaml:"include,omitempty" xml:"include,omitempty"` // INCLUDE columns
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"` Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the index name in lowercase for SQL compatibility. // SQLName returns the index name in lowercase for SQL compatibility.
@@ -268,6 +279,7 @@ type Relationship struct {
ThroughSchema string `json:"through_schema,omitempty" yaml:"through_schema,omitempty" xml:"through_schema,omitempty"` ThroughSchema string `json:"through_schema,omitempty" yaml:"through_schema,omitempty" xml:"through_schema,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty" xml:"description,omitempty"` Description string `json:"description,omitempty" yaml:"description,omitempty" xml:"description,omitempty"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the relationship name in lowercase for SQL compatibility. // SQLName returns the relationship name in lowercase for SQL compatibility.
@@ -292,6 +304,7 @@ type Constraint struct {
Deferrable bool `json:"deferrable,omitempty" yaml:"deferrable,omitempty" xml:"deferrable,omitempty"` Deferrable bool `json:"deferrable,omitempty" yaml:"deferrable,omitempty" xml:"deferrable,omitempty"`
InitiallyDeferred bool `json:"initially_deferred,omitempty" yaml:"initially_deferred,omitempty" xml:"initially_deferred,omitempty"` InitiallyDeferred bool `json:"initially_deferred,omitempty" yaml:"initially_deferred,omitempty" xml:"initially_deferred,omitempty"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the constraint name in lowercase for SQL compatibility. // SQLName returns the constraint name in lowercase for SQL compatibility.
@@ -307,6 +320,7 @@ type Enum struct {
Name string `json:"name" yaml:"name" xml:"name"` Name string `json:"name" yaml:"name" xml:"name"`
Values []string `json:"values" yaml:"values" xml:"values"` Values []string `json:"values" yaml:"values" xml:"values"`
Schema string `json:"schema,omitempty" yaml:"schema,omitempty" xml:"schema,omitempty"` Schema string `json:"schema,omitempty" yaml:"schema,omitempty" xml:"schema,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the enum name in lowercase for SQL compatibility. // SQLName returns the enum name in lowercase for SQL compatibility.
@@ -314,6 +328,16 @@ func (d *Enum) SQLName() string {
return strings.ToLower(d.Name) return strings.ToLower(d.Name)
} }
// InitEnum initializes a new Enum with empty values slice
func InitEnum(name, schema string) *Enum {
return &Enum{
Name: name,
Schema: schema,
Values: make([]string, 0),
GUID: uuid.New().String(),
}
}
// Supported constraint types. // Supported constraint types.
const ( const (
PrimaryKeyConstraint ConstraintType = "primary_key" // Primary key uniquely identifies each record PrimaryKeyConstraint ConstraintType = "primary_key" // Primary key uniquely identifies each record
@@ -335,6 +359,7 @@ type Script struct {
Version string `json:"version,omitempty" yaml:"version,omitempty" xml:"version,omitempty"` Version string `json:"version,omitempty" yaml:"version,omitempty" xml:"version,omitempty"`
Priority int `json:"priority,omitempty" yaml:"priority,omitempty" xml:"priority,omitempty"` Priority int `json:"priority,omitempty" yaml:"priority,omitempty" xml:"priority,omitempty"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"` Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
GUID string `json:"guid" yaml:"guid" xml:"guid"`
} }
// SQLName returns the script name in lowercase for SQL compatibility. // SQLName returns the script name in lowercase for SQL compatibility.
@@ -350,6 +375,7 @@ func InitDatabase(name string) *Database {
Name: name, Name: name,
Schemas: make([]*Schema, 0), Schemas: make([]*Schema, 0),
Domains: make([]*Domain, 0), Domains: make([]*Domain, 0),
GUID: uuid.New().String(),
} }
} }
@@ -363,6 +389,7 @@ func InitSchema(name string) *Schema {
Permissions: make(map[string]string), Permissions: make(map[string]string),
Metadata: make(map[string]any), Metadata: make(map[string]any),
Scripts: make([]*Script, 0), Scripts: make([]*Script, 0),
GUID: uuid.New().String(),
} }
} }
@@ -376,6 +403,7 @@ func InitTable(name, schema string) *Table {
Indexes: make(map[string]*Index), Indexes: make(map[string]*Index),
Relationships: make(map[string]*Relationship), Relationships: make(map[string]*Relationship),
Metadata: make(map[string]any), Metadata: make(map[string]any),
GUID: uuid.New().String(),
} }
} }
@@ -385,6 +413,7 @@ func InitColumn(name, table, schema string) *Column {
Name: name, Name: name,
Table: table, Table: table,
Schema: schema, Schema: schema,
GUID: uuid.New().String(),
} }
} }
@@ -396,6 +425,7 @@ func InitIndex(name, table, schema string) *Index {
Schema: schema, Schema: schema,
Columns: make([]string, 0), Columns: make([]string, 0),
Include: make([]string, 0), Include: make([]string, 0),
GUID: uuid.New().String(),
} }
} }
@@ -408,6 +438,7 @@ func InitRelation(name, schema string) *Relationship {
Properties: make(map[string]string), Properties: make(map[string]string),
FromColumns: make([]string, 0), FromColumns: make([]string, 0),
ToColumns: make([]string, 0), ToColumns: make([]string, 0),
GUID: uuid.New().String(),
} }
} }
@@ -417,6 +448,7 @@ func InitRelationship(name string, relType RelationType) *Relationship {
Name: name, Name: name,
Type: relType, Type: relType,
Properties: make(map[string]string), Properties: make(map[string]string),
GUID: uuid.New().String(),
} }
} }
@@ -427,6 +459,7 @@ func InitConstraint(name string, constraintType ConstraintType) *Constraint {
Type: constraintType, Type: constraintType,
Columns: make([]string, 0), Columns: make([]string, 0),
ReferencedColumns: make([]string, 0), ReferencedColumns: make([]string, 0),
GUID: uuid.New().String(),
} }
} }
@@ -435,6 +468,7 @@ func InitScript(name string) *Script {
return &Script{ return &Script{
Name: name, Name: name,
RunAfter: make([]string, 0), RunAfter: make([]string, 0),
GUID: uuid.New().String(),
} }
} }
@@ -445,6 +479,7 @@ func InitView(name, schema string) *View {
Schema: schema, Schema: schema,
Columns: make(map[string]*Column), Columns: make(map[string]*Column),
Metadata: make(map[string]any), Metadata: make(map[string]any),
GUID: uuid.New().String(),
} }
} }
@@ -455,6 +490,7 @@ func InitSequence(name, schema string) *Sequence {
Schema: schema, Schema: schema,
IncrementBy: 1, IncrementBy: 1,
StartValue: 1, StartValue: 1,
GUID: uuid.New().String(),
} }
} }
@@ -464,6 +500,7 @@ func InitDomain(name string) *Domain {
Name: name, Name: name,
Tables: make([]*DomainTable, 0), Tables: make([]*DomainTable, 0),
Metadata: make(map[string]any), Metadata: make(map[string]any),
GUID: uuid.New().String(),
} }
} }
@@ -472,5 +509,6 @@ func InitDomainTable(tableName, schemaName string) *DomainTable {
return &DomainTable{ return &DomainTable{
TableName: tableName, TableName: tableName,
SchemaName: schemaName, SchemaName: schemaName,
GUID: uuid.New().String(),
} }
} }

View File

@@ -632,6 +632,9 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
column.Name = parts[0] column.Name = parts[0]
} }
// Track if we found explicit nullability markers
hasExplicitNullableMarker := false
// Parse tag attributes // Parse tag attributes
for _, part := range parts[1:] { for _, part := range parts[1:] {
kv := strings.SplitN(part, ":", 2) kv := strings.SplitN(part, ":", 2)
@@ -649,6 +652,10 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
column.IsPrimaryKey = true column.IsPrimaryKey = true
case "notnull": case "notnull":
column.NotNull = true column.NotNull = true
hasExplicitNullableMarker = true
case "nullzero":
column.NotNull = false
hasExplicitNullableMarker = true
case "autoincrement": case "autoincrement":
column.AutoIncrement = true column.AutoIncrement = true
case "default": case "default":
@@ -664,17 +671,15 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
// Determine if nullable based on Go type and bun tags // Determine if nullable based on Go type and bun tags
// In Bun: // In Bun:
// - nullzero tag means the field is nullable (can be NULL in DB) // - explicit "notnull" tag means NOT NULL
// - absence of nullzero means the field is NOT NULL // - explicit "nullzero" tag means nullable
// - primitive types (int64, bool, string) are NOT NULL by default // - absence of explicit markers: infer from Go type
column.NotNull = true if !hasExplicitNullableMarker {
// Primary keys are always NOT NULL // Infer from Go type if no explicit marker found
if strings.Contains(bunTag, "nullzero") {
column.NotNull = false
} else {
column.NotNull = !r.isNullableGoType(fieldType) column.NotNull = !r.isNullableGoType(fieldType)
} }
// Primary keys are always NOT NULL
if column.IsPrimaryKey { if column.IsPrimaryKey {
column.NotNull = true column.NotNull = true
} }

View File

@@ -4,7 +4,9 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"os" "os"
"path/filepath"
"regexp" "regexp"
"sort"
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
@@ -24,11 +26,23 @@ func NewReader(options *readers.ReaderOptions) *Reader {
} }
// ReadDatabase reads and parses DBML input, returning a Database model // ReadDatabase reads and parses DBML input, returning a Database model
// If FilePath points to a directory, all .dbml files are loaded and merged
func (r *Reader) ReadDatabase() (*models.Database, error) { func (r *Reader) ReadDatabase() (*models.Database, error) {
if r.options.FilePath == "" { if r.options.FilePath == "" {
return nil, fmt.Errorf("file path is required for DBML reader") return nil, fmt.Errorf("file path is required for DBML reader")
} }
// Check if path is a directory
info, err := os.Stat(r.options.FilePath)
if err != nil {
return nil, fmt.Errorf("failed to stat path: %w", err)
}
if info.IsDir() {
return r.readDirectoryDBML(r.options.FilePath)
}
// Single file - existing logic
content, err := os.ReadFile(r.options.FilePath) content, err := os.ReadFile(r.options.FilePath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err) return nil, fmt.Errorf("failed to read file: %w", err)
@@ -67,15 +81,301 @@ func (r *Reader) ReadTable() (*models.Table, error) {
return schema.Tables[0], nil return schema.Tables[0], nil
} }
// stripQuotes removes surrounding quotes from an identifier // readDirectoryDBML processes all .dbml files in directory
// Returns merged Database model
func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
// Discover and sort DBML files
files, err := r.discoverDBMLFiles(dirPath)
if err != nil {
return nil, fmt.Errorf("failed to discover DBML files: %w", err)
}
// If no files found, return empty database
if len(files) == 0 {
db := models.InitDatabase("database")
if r.options.Metadata != nil {
if name, ok := r.options.Metadata["name"].(string); ok {
db.Name = name
}
}
return db, nil
}
// Initialize database (will be merged with files)
var db *models.Database
// Process each file in sorted order
for _, filePath := range files {
content, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read file %s: %w", filePath, err)
}
fileDB, err := r.parseDBML(string(content))
if err != nil {
return nil, fmt.Errorf("failed to parse file %s: %w", filePath, err)
}
// First file initializes the database
if db == nil {
db = fileDB
} else {
// Subsequent files are merged
mergeDatabase(db, fileDB)
}
}
return db, nil
}
// stripQuotes removes surrounding quotes and comments from an identifier
func stripQuotes(s string) string { func stripQuotes(s string) string {
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
// Remove DBML comments in brackets (e.g., [note: 'description'])
// This handles inline comments like: "table_name" [note: 'comment']
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
s = commentRegex.ReplaceAllString(s, "")
// Trim again after removing comments
s = strings.TrimSpace(s)
// Remove surrounding quotes (double or single)
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) { if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
return s[1 : len(s)-1] return s[1 : len(s)-1]
} }
return s return s
} }
// parseFilePrefix extracts numeric prefix from filename
// Examples: "1_schema.dbml" -> (1, true), "tables.dbml" -> (0, false)
func parseFilePrefix(filename string) (int, bool) {
base := filepath.Base(filename)
re := regexp.MustCompile(`^(\d+)[_-]`)
matches := re.FindStringSubmatch(base)
if len(matches) > 1 {
var prefix int
_, err := fmt.Sscanf(matches[1], "%d", &prefix)
if err == nil {
return prefix, true
}
}
return 0, false
}
// hasCommentedRefs scans file content for commented-out Ref statements
// Returns true if file contains lines like: // Ref: table.col > other.col
func hasCommentedRefs(filePath string) (bool, error) {
content, err := os.ReadFile(filePath)
if err != nil {
return false, err
}
scanner := bufio.NewScanner(strings.NewReader(string(content)))
commentedRefRegex := regexp.MustCompile(`^\s*//.*Ref:\s+`)
for scanner.Scan() {
line := scanner.Text()
if commentedRefRegex.MatchString(line) {
return true, nil
}
}
return false, nil
}
// discoverDBMLFiles finds all .dbml files in directory and returns them sorted
func (r *Reader) discoverDBMLFiles(dirPath string) ([]string, error) {
pattern := filepath.Join(dirPath, "*.dbml")
files, err := filepath.Glob(pattern)
if err != nil {
return nil, fmt.Errorf("failed to glob .dbml files: %w", err)
}
return sortDBMLFiles(files), nil
}
// sortDBMLFiles sorts files by:
// 1. Files without commented refs (by numeric prefix, then alphabetically)
// 2. Files with commented refs (by numeric prefix, then alphabetically)
func sortDBMLFiles(files []string) []string {
// Create a slice to hold file info for sorting
type fileInfo struct {
path string
hasCommented bool
prefix int
hasPrefix bool
basename string
}
fileInfos := make([]fileInfo, 0, len(files))
for _, file := range files {
hasCommented, err := hasCommentedRefs(file)
if err != nil {
// If we can't read the file, treat it as not having commented refs
hasCommented = false
}
prefix, hasPrefix := parseFilePrefix(file)
basename := filepath.Base(file)
fileInfos = append(fileInfos, fileInfo{
path: file,
hasCommented: hasCommented,
prefix: prefix,
hasPrefix: hasPrefix,
basename: basename,
})
}
// Sort by: hasCommented (false first), hasPrefix (true first), prefix, basename
sort.Slice(fileInfos, func(i, j int) bool {
// First, sort by commented refs (files without commented refs come first)
if fileInfos[i].hasCommented != fileInfos[j].hasCommented {
return !fileInfos[i].hasCommented
}
// Then by presence of prefix (files with prefix come first)
if fileInfos[i].hasPrefix != fileInfos[j].hasPrefix {
return fileInfos[i].hasPrefix
}
// If both have prefix, sort by prefix value
if fileInfos[i].hasPrefix && fileInfos[j].hasPrefix {
if fileInfos[i].prefix != fileInfos[j].prefix {
return fileInfos[i].prefix < fileInfos[j].prefix
}
}
// Finally, sort alphabetically by basename
return fileInfos[i].basename < fileInfos[j].basename
})
// Extract sorted paths
sortedFiles := make([]string, len(fileInfos))
for i, info := range fileInfos {
sortedFiles[i] = info.path
}
return sortedFiles
}
// mergeTable combines two table definitions
// Merges: Columns (map), Constraints (map), Indexes (map), Relationships (map)
// Uses first non-empty Description
func mergeTable(baseTable, fileTable *models.Table) {
// Merge columns (map naturally merges - later keys overwrite)
for key, col := range fileTable.Columns {
baseTable.Columns[key] = col
}
// Merge constraints
for key, constraint := range fileTable.Constraints {
baseTable.Constraints[key] = constraint
}
// Merge indexes
for key, index := range fileTable.Indexes {
baseTable.Indexes[key] = index
}
// Merge relationships
for key, rel := range fileTable.Relationships {
baseTable.Relationships[key] = rel
}
// Use first non-empty description
if baseTable.Description == "" && fileTable.Description != "" {
baseTable.Description = fileTable.Description
}
// Merge metadata maps
if baseTable.Metadata == nil {
baseTable.Metadata = make(map[string]any)
}
for key, val := range fileTable.Metadata {
baseTable.Metadata[key] = val
}
}
// mergeSchema finds or creates schema and merges tables
func mergeSchema(baseDB *models.Database, fileSchema *models.Schema) {
// Find existing schema by name (normalize names by stripping quotes)
var existingSchema *models.Schema
fileSchemaName := stripQuotes(fileSchema.Name)
for _, schema := range baseDB.Schemas {
if stripQuotes(schema.Name) == fileSchemaName {
existingSchema = schema
break
}
}
// If schema doesn't exist, add it and return
if existingSchema == nil {
baseDB.Schemas = append(baseDB.Schemas, fileSchema)
return
}
// Merge tables from fileSchema into existingSchema
for _, fileTable := range fileSchema.Tables {
// Find existing table by name (normalize names by stripping quotes)
var existingTable *models.Table
fileTableName := stripQuotes(fileTable.Name)
for _, table := range existingSchema.Tables {
if stripQuotes(table.Name) == fileTableName {
existingTable = table
break
}
}
// If table doesn't exist, add it
if existingTable == nil {
existingSchema.Tables = append(existingSchema.Tables, fileTable)
} else {
// Merge table properties - tables are identical, skip
mergeTable(existingTable, fileTable)
}
}
// Merge other schema properties
existingSchema.Views = append(existingSchema.Views, fileSchema.Views...)
existingSchema.Sequences = append(existingSchema.Sequences, fileSchema.Sequences...)
existingSchema.Scripts = append(existingSchema.Scripts, fileSchema.Scripts...)
// Merge permissions
if existingSchema.Permissions == nil {
existingSchema.Permissions = make(map[string]string)
}
for key, val := range fileSchema.Permissions {
existingSchema.Permissions[key] = val
}
// Merge metadata
if existingSchema.Metadata == nil {
existingSchema.Metadata = make(map[string]any)
}
for key, val := range fileSchema.Metadata {
existingSchema.Metadata[key] = val
}
}
// mergeDatabase merges schemas from fileDB into baseDB
func mergeDatabase(baseDB, fileDB *models.Database) {
// Merge each schema from fileDB
for _, fileSchema := range fileDB.Schemas {
mergeSchema(baseDB, fileSchema)
}
// Merge domains
baseDB.Domains = append(baseDB.Domains, fileDB.Domains...)
// Use first non-empty description
if baseDB.Description == "" && fileDB.Description != "" {
baseDB.Description = fileDB.Description
}
}
// parseDBML parses DBML content and returns a Database model // parseDBML parses DBML content and returns a Database model
func (r *Reader) parseDBML(content string) (*models.Database, error) { func (r *Reader) parseDBML(content string) (*models.Database, error) {
db := models.InitDatabase("database") db := models.InitDatabase("database")
@@ -332,29 +632,33 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Format: (columns) [attributes] OR columnname [attributes] // Format: (columns) [attributes] OR columnname [attributes]
var columns []string var columns []string
if strings.Contains(line, "(") && strings.Contains(line, ")") { // Find the attributes section to avoid parsing parentheses in notes/attributes
attrStart := strings.Index(line, "[")
columnPart := line
if attrStart > 0 {
columnPart = line[:attrStart]
}
if strings.Contains(columnPart, "(") && strings.Contains(columnPart, ")") {
// Multi-column format: (col1, col2) [attributes] // Multi-column format: (col1, col2) [attributes]
colStart := strings.Index(line, "(") colStart := strings.Index(columnPart, "(")
colEnd := strings.Index(line, ")") colEnd := strings.Index(columnPart, ")")
if colStart >= colEnd { if colStart >= colEnd {
return nil return nil
} }
columnsStr := line[colStart+1 : colEnd] columnsStr := columnPart[colStart+1 : colEnd]
for _, col := range strings.Split(columnsStr, ",") { for _, col := range strings.Split(columnsStr, ",") {
columns = append(columns, stripQuotes(strings.TrimSpace(col))) columns = append(columns, stripQuotes(strings.TrimSpace(col)))
} }
} else if strings.Contains(line, "[") { } else if attrStart > 0 {
// Single column format: columnname [attributes] // Single column format: columnname [attributes]
// Extract column name before the bracket // Extract column name before the bracket
idx := strings.Index(line, "[") colName := strings.TrimSpace(columnPart)
if idx > 0 {
colName := strings.TrimSpace(line[:idx])
if colName != "" { if colName != "" {
columns = []string{stripQuotes(colName)} columns = []string{stripQuotes(colName)}
} }
} }
}
if len(columns) == 0 { if len(columns) == 0 {
return nil return nil

View File

@@ -1,6 +1,7 @@
package dbml package dbml
import ( import (
"os"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -517,3 +518,286 @@ func TestGetForeignKeys(t *testing.T) {
t.Error("Expected foreign key constraint type") t.Error("Expected foreign key constraint type")
} }
} }
// Tests for multi-file directory loading
func TestReadDirectory_MultipleFiles(t *testing.T) {
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"),
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() error = %v", err)
}
if db == nil {
t.Fatal("ReadDatabase() returned nil database")
}
// Should have public schema
if len(db.Schemas) == 0 {
t.Fatal("Expected at least one schema")
}
var publicSchema *models.Schema
for _, schema := range db.Schemas {
if schema.Name == "public" {
publicSchema = schema
break
}
}
if publicSchema == nil {
t.Fatal("Public schema not found")
}
// Should have 3 tables: users, posts, comments
if len(publicSchema.Tables) != 3 {
t.Fatalf("Expected 3 tables, got %d", len(publicSchema.Tables))
}
// Find tables
var usersTable, postsTable, commentsTable *models.Table
for _, table := range publicSchema.Tables {
switch table.Name {
case "users":
usersTable = table
case "posts":
postsTable = table
case "comments":
commentsTable = table
}
}
if usersTable == nil {
t.Fatal("Users table not found")
}
if postsTable == nil {
t.Fatal("Posts table not found")
}
if commentsTable == nil {
t.Fatal("Comments table not found")
}
// Verify users table has merged columns from 1_users.dbml and 3_add_columns.dbml
expectedUserColumns := []string{"id", "email", "name", "created_at"}
if len(usersTable.Columns) != len(expectedUserColumns) {
t.Errorf("Expected %d columns in users table, got %d", len(expectedUserColumns), len(usersTable.Columns))
}
for _, colName := range expectedUserColumns {
if _, exists := usersTable.Columns[colName]; !exists {
t.Errorf("Expected column '%s' in users table", colName)
}
}
// Verify posts table columns
expectedPostColumns := []string{"id", "user_id", "title", "content", "created_at"}
for _, colName := range expectedPostColumns {
if _, exists := postsTable.Columns[colName]; !exists {
t.Errorf("Expected column '%s' in posts table", colName)
}
}
}
func TestReadDirectory_TableMerging(t *testing.T) {
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"),
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() error = %v", err)
}
// Find users table
var usersTable *models.Table
for _, schema := range db.Schemas {
for _, table := range schema.Tables {
if table.Name == "users" && schema.Name == "public" {
usersTable = table
break
}
}
}
if usersTable == nil {
t.Fatal("Users table not found")
}
// Verify columns from file 1 (id, email)
if _, exists := usersTable.Columns["id"]; !exists {
t.Error("Column 'id' from 1_users.dbml not found")
}
if _, exists := usersTable.Columns["email"]; !exists {
t.Error("Column 'email' from 1_users.dbml not found")
}
// Verify columns from file 3 (name, created_at)
if _, exists := usersTable.Columns["name"]; !exists {
t.Error("Column 'name' from 3_add_columns.dbml not found")
}
if _, exists := usersTable.Columns["created_at"]; !exists {
t.Error("Column 'created_at' from 3_add_columns.dbml not found")
}
// Verify column properties from file 1
emailCol := usersTable.Columns["email"]
if !emailCol.NotNull {
t.Error("Email column should be not null (from 1_users.dbml)")
}
if emailCol.Type != "varchar(255)" {
t.Errorf("Expected email type 'varchar(255)', got '%s'", emailCol.Type)
}
}
func TestReadDirectory_CommentedRefsLast(t *testing.T) {
// This test verifies that files with commented refs are processed last
// by checking that the file discovery returns them in the correct order
dirPath := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile")
opts := &readers.ReaderOptions{
FilePath: dirPath,
}
reader := NewReader(opts)
files, err := reader.discoverDBMLFiles(dirPath)
if err != nil {
t.Fatalf("discoverDBMLFiles() error = %v", err)
}
if len(files) < 2 {
t.Skip("Not enough files to test ordering")
}
// Check that 9_refs.dbml (which has commented refs) comes last
lastFile := filepath.Base(files[len(files)-1])
if lastFile != "9_refs.dbml" {
t.Errorf("Expected last file to be '9_refs.dbml' (has commented refs), got '%s'", lastFile)
}
// Check that numbered files without commented refs come first
firstFile := filepath.Base(files[0])
if firstFile != "1_users.dbml" {
t.Errorf("Expected first file to be '1_users.dbml', got '%s'", firstFile)
}
}
func TestReadDirectory_EmptyDirectory(t *testing.T) {
// Create a temporary empty directory
tmpDir := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "empty_test_dir")
err := os.MkdirAll(tmpDir, 0755)
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
opts := &readers.ReaderOptions{
FilePath: tmpDir,
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() should not error on empty directory, got: %v", err)
}
if db == nil {
t.Fatal("ReadDatabase() returned nil database")
}
// Empty directory should return empty database
if len(db.Schemas) != 0 {
t.Errorf("Expected 0 schemas for empty directory, got %d", len(db.Schemas))
}
}
func TestReadDatabase_BackwardCompat(t *testing.T) {
// Test that single file loading still works
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "simple.dbml"),
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() error = %v", err)
}
if db == nil {
t.Fatal("ReadDatabase() returned nil database")
}
if len(db.Schemas) == 0 {
t.Fatal("Expected at least one schema")
}
schema := db.Schemas[0]
if len(schema.Tables) != 1 {
t.Fatalf("Expected 1 table, got %d", len(schema.Tables))
}
table := schema.Tables[0]
if table.Name != "users" {
t.Errorf("Expected table name 'users', got '%s'", table.Name)
}
}
func TestParseFilePrefix(t *testing.T) {
tests := []struct {
filename string
wantPrefix int
wantHas bool
}{
{"1_schema.dbml", 1, true},
{"2_tables.dbml", 2, true},
{"10_relationships.dbml", 10, true},
{"99_data.dbml", 99, true},
{"schema.dbml", 0, false},
{"tables_no_prefix.dbml", 0, false},
{"/path/to/1_file.dbml", 1, true},
{"/path/to/file.dbml", 0, false},
{"1-file.dbml", 1, true},
{"2-another.dbml", 2, true},
}
for _, tt := range tests {
t.Run(tt.filename, func(t *testing.T) {
gotPrefix, gotHas := parseFilePrefix(tt.filename)
if gotPrefix != tt.wantPrefix {
t.Errorf("parseFilePrefix(%s) prefix = %d, want %d", tt.filename, gotPrefix, tt.wantPrefix)
}
if gotHas != tt.wantHas {
t.Errorf("parseFilePrefix(%s) hasPrefix = %v, want %v", tt.filename, gotHas, tt.wantHas)
}
})
}
}
func TestHasCommentedRefs(t *testing.T) {
// Test with the actual multifile test fixtures
tests := []struct {
filename string
wantHas bool
}{
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "1_users.dbml"), false},
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "2_posts.dbml"), false},
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "3_add_columns.dbml"), false},
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "9_refs.dbml"), true},
}
for _, tt := range tests {
t.Run(filepath.Base(tt.filename), func(t *testing.T) {
gotHas, err := hasCommentedRefs(tt.filename)
if err != nil {
t.Fatalf("hasCommentedRefs() error = %v", err)
}
if gotHas != tt.wantHas {
t.Errorf("hasCommentedRefs(%s) = %v, want %v", filepath.Base(tt.filename), gotHas, tt.wantHas)
}
})
}
}

View File

@@ -79,6 +79,8 @@ func (r *Reader) convertToDatabase(dctx *models.DCTXDictionary) (*models.Databas
db := models.InitDatabase(dbName) db := models.InitDatabase(dbName)
schema := models.InitSchema("public") schema := models.InitSchema("public")
// Note: DCTX doesn't have database GUID, but schema can use dictionary name if available
// Create GUID mappings for tables and keys // Create GUID mappings for tables and keys
tableGuidMap := make(map[string]string) // GUID -> table name tableGuidMap := make(map[string]string) // GUID -> table name
keyGuidMap := make(map[string]*models.DCTXKey) // GUID -> key definition keyGuidMap := make(map[string]*models.DCTXKey) // GUID -> key definition
@@ -162,6 +164,10 @@ func (r *Reader) convertTable(dctxTable *models.DCTXTable) (*models.Table, map[s
tableName := r.sanitizeName(dctxTable.Name) tableName := r.sanitizeName(dctxTable.Name)
table := models.InitTable(tableName, "public") table := models.InitTable(tableName, "public")
table.Description = dctxTable.Description table.Description = dctxTable.Description
// Assign GUID from DCTX table
if dctxTable.Guid != "" {
table.GUID = dctxTable.Guid
}
fieldGuidMap := make(map[string]string) fieldGuidMap := make(map[string]string)
@@ -202,6 +208,10 @@ func (r *Reader) convertField(dctxField *models.DCTXField, tableName string) ([]
// Convert single field // Convert single field
column := models.InitColumn(r.sanitizeName(dctxField.Name), tableName, "public") column := models.InitColumn(r.sanitizeName(dctxField.Name), tableName, "public")
// Assign GUID from DCTX field
if dctxField.Guid != "" {
column.GUID = dctxField.Guid
}
// Map Clarion data types // Map Clarion data types
dataType, length := r.mapDataType(dctxField.DataType, dctxField.Size) dataType, length := r.mapDataType(dctxField.DataType, dctxField.Size)
@@ -346,6 +356,10 @@ func (r *Reader) convertKey(dctxKey *models.DCTXKey, table *models.Table, fieldG
constraint.Table = table.Name constraint.Table = table.Name
constraint.Schema = table.Schema constraint.Schema = table.Schema
constraint.Columns = columns constraint.Columns = columns
// Assign GUID from DCTX key
if dctxKey.Guid != "" {
constraint.GUID = dctxKey.Guid
}
table.Constraints[constraint.Name] = constraint table.Constraints[constraint.Name] = constraint
@@ -366,6 +380,10 @@ func (r *Reader) convertKey(dctxKey *models.DCTXKey, table *models.Table, fieldG
index.Columns = columns index.Columns = columns
index.Unique = dctxKey.Unique index.Unique = dctxKey.Unique
index.Type = "btree" index.Type = "btree"
// Assign GUID from DCTX key
if dctxKey.Guid != "" {
index.GUID = dctxKey.Guid
}
table.Indexes[index.Name] = index table.Indexes[index.Name] = index
return nil return nil
@@ -460,6 +478,10 @@ func (r *Reader) processRelations(dctx *models.DCTXDictionary, schema *models.Sc
constraint.ReferencedColumns = pkColumns constraint.ReferencedColumns = pkColumns
constraint.OnDelete = r.mapReferentialAction(relation.Delete) constraint.OnDelete = r.mapReferentialAction(relation.Delete)
constraint.OnUpdate = r.mapReferentialAction(relation.Update) constraint.OnUpdate = r.mapReferentialAction(relation.Update)
// Assign GUID from DCTX relation
if relation.Guid != "" {
constraint.GUID = relation.Guid
}
foreignTable.Constraints[fkName] = constraint foreignTable.Constraints[fkName] = constraint
@@ -473,6 +495,10 @@ func (r *Reader) processRelations(dctx *models.DCTXDictionary, schema *models.Sc
relationship.ForeignKey = fkName relationship.ForeignKey = fkName
relationship.Properties["on_delete"] = constraint.OnDelete relationship.Properties["on_delete"] = constraint.OnDelete
relationship.Properties["on_update"] = constraint.OnUpdate relationship.Properties["on_update"] = constraint.OnUpdate
// Assign GUID from DCTX relation
if relation.Guid != "" {
relationship.GUID = relation.Guid
}
foreignTable.Relationships[relationshipName] = relationship foreignTable.Relationships[relationshipName] = relationship
} }

View File

@@ -241,11 +241,9 @@ func (r *Reader) parsePgEnum(line string, matches []string) *models.Enum {
} }
} }
return &models.Enum{ enum := models.InitEnum(enumName, "public")
Name: enumName, enum.Values = values
Values: values, return enum
Schema: "public",
}
} }
// parseTableBlock parses a complete pgTable definition block // parseTableBlock parses a complete pgTable definition block

View File

@@ -260,11 +260,7 @@ func (r *Reader) parseType(typeName string, lines []string, schema *models.Schem
} }
func (r *Reader) parseEnum(enumName string, lines []string, schema *models.Schema) { func (r *Reader) parseEnum(enumName string, lines []string, schema *models.Schema) {
enum := &models.Enum{ enum := models.InitEnum(enumName, schema.Name)
Name: enumName,
Schema: schema.Name,
Values: make([]string, 0),
}
for _, line := range lines { for _, line := range lines {
trimmed := strings.TrimSpace(line) trimmed := strings.TrimSpace(line)

View File

@@ -128,11 +128,7 @@ func (r *Reader) parsePrisma(content string) (*models.Database, error) {
if matches := enumRegex.FindStringSubmatch(trimmed); matches != nil { if matches := enumRegex.FindStringSubmatch(trimmed); matches != nil {
currentBlock = "enum" currentBlock = "enum"
enumName := matches[1] enumName := matches[1]
currentEnum = &models.Enum{ currentEnum = models.InitEnum(enumName, "public")
Name: enumName,
Schema: "public",
Values: make([]string, 0),
}
blockContent = []string{} blockContent = []string{}
continue continue
} }

View File

@@ -150,13 +150,11 @@ func (r *Reader) readScripts() ([]*models.Script, error) {
} }
// Create Script model // Create Script model
script := &models.Script{ script := models.InitScript(name)
Name: name, script.Description = fmt.Sprintf("SQL script from %s", relPath)
Description: fmt.Sprintf("SQL script from %s", relPath), script.SQL = string(content)
SQL: string(content), script.Priority = priority
Priority: priority, script.Sequence = uint(sequence)
Sequence: uint(sequence),
}
scripts = append(scripts, script) scripts = append(scripts, script)

View File

@@ -23,6 +23,7 @@ func (se *SchemaEditor) showColumnEditor(schemaIndex, tableIndex, colIndex int,
newIsNotNull := column.NotNull newIsNotNull := column.NotNull
newDefault := column.Default newDefault := column.Default
newDescription := column.Description newDescription := column.Description
newGUID := column.GUID
// Column type options: PostgreSQL, MySQL, SQL Server, and common SQL types // Column type options: PostgreSQL, MySQL, SQL Server, and common SQL types
columnTypes := []string{ columnTypes := []string{
@@ -94,9 +95,14 @@ func (se *SchemaEditor) showColumnEditor(schemaIndex, tableIndex, colIndex int,
newDescription = value newDescription = value
}) })
form.AddInputField("GUID", column.GUID, 40, nil, func(value string) {
newGUID = value
})
form.AddButton("Save", func() { form.AddButton("Save", func() {
// Apply changes using dataops // Apply changes using dataops
se.UpdateColumn(schemaIndex, tableIndex, originalName, newName, newType, newIsPK, newIsNotNull, newDefault, newDescription) se.UpdateColumn(schemaIndex, tableIndex, originalName, newName, newType, newIsPK, newIsNotNull, newDefault, newDescription)
se.db.Schemas[schemaIndex].Tables[tableIndex].Columns[newName].GUID = newGUID
se.pages.RemovePage("column-editor") se.pages.RemovePage("column-editor")
se.pages.SwitchToPage("table-editor") se.pages.SwitchToPage("table-editor")

View File

@@ -14,6 +14,7 @@ func (se *SchemaEditor) showEditDatabaseForm() {
dbComment := se.db.Comment dbComment := se.db.Comment
dbType := string(se.db.DatabaseType) dbType := string(se.db.DatabaseType)
dbVersion := se.db.DatabaseVersion dbVersion := se.db.DatabaseVersion
dbGUID := se.db.GUID
// Database type options // Database type options
dbTypeOptions := []string{"pgsql", "mssql", "sqlite"} dbTypeOptions := []string{"pgsql", "mssql", "sqlite"}
@@ -45,11 +46,16 @@ func (se *SchemaEditor) showEditDatabaseForm() {
dbVersion = value dbVersion = value
}) })
form.AddInputField("GUID", dbGUID, 40, nil, func(value string) {
dbGUID = value
})
form.AddButton("Save", func() { form.AddButton("Save", func() {
if dbName == "" { if dbName == "" {
return return
} }
se.updateDatabase(dbName, dbDescription, dbComment, dbType, dbVersion) se.updateDatabase(dbName, dbDescription, dbComment, dbType, dbVersion)
se.db.GUID = dbGUID
se.pages.RemovePage("edit-database") se.pages.RemovePage("edit-database")
se.pages.RemovePage("main") se.pages.RemovePage("main")
se.pages.AddPage("main", se.createMainMenu(), true, true) se.pages.AddPage("main", se.createMainMenu(), true, true)

View File

@@ -4,10 +4,12 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/gdamore/tcell/v2" "github.com/gdamore/tcell/v2"
"github.com/rivo/tview" "github.com/rivo/tview"
"git.warky.dev/wdevs/relspecgo/pkg/merge"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/readers"
rbun "git.warky.dev/wdevs/relspecgo/pkg/readers/bun" rbun "git.warky.dev/wdevs/relspecgo/pkg/readers/bun"
@@ -107,8 +109,7 @@ func (se *SchemaEditor) showLoadScreen() {
// Keyboard shortcuts // Keyboard shortcuts
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
switch event.Key() { if event.Key() == tcell.KeyEscape {
case tcell.KeyEscape:
se.app.Stop() se.app.Stop()
return nil return nil
} }
@@ -214,8 +215,7 @@ func (se *SchemaEditor) showSaveScreen() {
// Keyboard shortcuts // Keyboard shortcuts
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
switch event.Key() { if event.Key() == tcell.KeyEscape {
case tcell.KeyEscape:
se.pages.RemovePage("save-database") se.pages.RemovePage("save-database")
se.pages.SwitchToPage("main") se.pages.SwitchToPage("main")
return nil return nil
@@ -524,3 +524,268 @@ Examples:
- File: ~/schemas/mydb.dbml - File: ~/schemas/mydb.dbml
- Directory (for code formats): ./models/` - Directory (for code formats): ./models/`
} }
// showImportScreen displays the import/merge database screen
func (se *SchemaEditor) showImportScreen() {
flex := tview.NewFlex().SetDirection(tview.FlexRow)
// Title
title := tview.NewTextView().
SetText("[::b]Import & Merge Database Schema").
SetTextAlign(tview.AlignCenter).
SetDynamicColors(true)
// Form
form := tview.NewForm()
form.SetBorder(true).SetTitle(" Import Configuration ").SetTitleAlign(tview.AlignLeft)
// Format selection
formatOptions := []string{
"dbml", "dctx", "drawdb", "graphql", "json", "yaml",
"gorm", "bun", "drizzle", "prisma", "typeorm", "pgsql",
}
selectedFormat := 0
currentFormat := formatOptions[selectedFormat]
// File path input
filePath := ""
connString := ""
skipDomains := false
skipRelations := false
skipEnums := false
skipViews := false
skipSequences := false
skipTables := ""
form.AddDropDown("Format", formatOptions, 0, func(option string, index int) {
selectedFormat = index
currentFormat = option
})
form.AddInputField("File Path", "", 50, nil, func(value string) {
filePath = value
})
form.AddInputField("Connection String", "", 50, nil, func(value string) {
connString = value
})
form.AddInputField("Skip Tables (comma-separated)", "", 50, nil, func(value string) {
skipTables = value
})
form.AddCheckbox("Skip Domains", false, func(checked bool) {
skipDomains = checked
})
form.AddCheckbox("Skip Relations", false, func(checked bool) {
skipRelations = checked
})
form.AddCheckbox("Skip Enums", false, func(checked bool) {
skipEnums = checked
})
form.AddCheckbox("Skip Views", false, func(checked bool) {
skipViews = checked
})
form.AddCheckbox("Skip Sequences", false, func(checked bool) {
skipSequences = checked
})
form.AddTextView("Help", getImportHelpText(), 0, 7, true, false)
// Buttons
form.AddButton("Import & Merge [i]", func() {
se.importAndMergeDatabase(currentFormat, filePath, connString, skipDomains, skipRelations, skipEnums, skipViews, skipSequences, skipTables)
})
form.AddButton("Back [b]", func() {
se.pages.RemovePage("import-database")
se.pages.SwitchToPage("main")
})
form.AddButton("Exit [q]", func() {
se.app.Stop()
})
// Keyboard shortcuts
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Key() == tcell.KeyEscape {
se.pages.RemovePage("import-database")
se.pages.SwitchToPage("main")
return nil
}
switch event.Rune() {
case 'i':
se.importAndMergeDatabase(currentFormat, filePath, connString, skipDomains, skipRelations, skipEnums, skipViews, skipSequences, skipTables)
return nil
case 'b':
se.pages.RemovePage("import-database")
se.pages.SwitchToPage("main")
return nil
case 'q':
se.app.Stop()
return nil
}
return event
})
flex.AddItem(title, 1, 0, false).
AddItem(form, 0, 1, true)
se.pages.AddAndSwitchToPage("import-database", flex, true)
}
// importAndMergeDatabase imports and merges a database from the specified configuration
func (se *SchemaEditor) importAndMergeDatabase(format, filePath, connString string, skipDomains, skipRelations, skipEnums, skipViews, skipSequences bool, skipTables string) {
// Validate input
if format == "pgsql" {
if connString == "" {
se.showErrorDialog("Error", "Connection string is required for PostgreSQL")
return
}
} else {
if filePath == "" {
se.showErrorDialog("Error", "File path is required for "+format)
return
}
// Expand home directory
if len(filePath) > 0 && filePath[0] == '~' {
home, err := os.UserHomeDir()
if err == nil {
filePath = filepath.Join(home, filePath[1:])
}
}
}
// Create reader
var reader readers.Reader
switch format {
case "dbml":
reader = rdbml.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "dctx":
reader = rdctx.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "drawdb":
reader = rdrawdb.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "graphql":
reader = rgraphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "json":
reader = rjson.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "yaml":
reader = ryaml.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "gorm":
reader = rgorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "bun":
reader = rbun.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "drizzle":
reader = rdrizzle.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "prisma":
reader = rprisma.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "typeorm":
reader = rtypeorm.NewReader(&readers.ReaderOptions{FilePath: filePath})
case "pgsql":
reader = rpgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
default:
se.showErrorDialog("Error", "Unsupported format: "+format)
return
}
// Read the database to import
importDb, err := reader.ReadDatabase()
if err != nil {
se.showErrorDialog("Import Error", fmt.Sprintf("Failed to read database: %v", err))
return
}
// Show confirmation dialog
se.showImportConfirmation(importDb, skipDomains, skipRelations, skipEnums, skipViews, skipSequences, skipTables)
}
// showImportConfirmation shows a confirmation dialog before merging
func (se *SchemaEditor) showImportConfirmation(importDb *models.Database, skipDomains, skipRelations, skipEnums, skipViews, skipSequences bool, skipTables string) {
confirmText := fmt.Sprintf("Import & Merge Database?\n\nSource: %s\nTarget: %s\n\nThis will add missing schemas, tables, columns, and other objects from the source to your database.\n\nExisting items will NOT be modified.",
importDb.Name, se.db.Name)
modal := tview.NewModal().
SetText(confirmText).
AddButtons([]string{"Cancel", "Merge"}).
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
se.pages.RemovePage("import-confirm")
if buttonLabel == "Merge" {
se.performMerge(importDb, skipDomains, skipRelations, skipEnums, skipViews, skipSequences, skipTables)
}
})
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Key() == tcell.KeyEscape {
se.pages.RemovePage("import-confirm")
se.pages.SwitchToPage("import-database")
return nil
}
return event
})
se.pages.AddAndSwitchToPage("import-confirm", modal, true)
}
// performMerge performs the actual merge operation
func (se *SchemaEditor) performMerge(importDb *models.Database, skipDomains, skipRelations, skipEnums, skipViews, skipSequences bool, skipTables string) {
// Create merge options
opts := &merge.MergeOptions{
SkipDomains: skipDomains,
SkipRelations: skipRelations,
SkipEnums: skipEnums,
SkipViews: skipViews,
SkipSequences: skipSequences,
}
// Parse skip tables
if skipTables != "" {
opts.SkipTableNames = parseSkipTablesUI(skipTables)
}
// Perform the merge
result := merge.MergeDatabases(se.db, importDb, opts)
// Update the database timestamp
se.db.UpdateDate()
// Show success dialog with summary
summary := merge.GetMergeSummary(result)
se.showSuccessDialog("Import Complete", summary, func() {
se.pages.RemovePage("import-database")
se.pages.RemovePage("main")
se.pages.AddPage("main", se.createMainMenu(), true, true)
})
}
// getImportHelpText returns the help text for the import screen
func getImportHelpText() string {
return `Import & Merge: Adds missing schemas, tables, columns, and other objects to your existing database.
File-based formats: dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm
Database formats: pgsql (requires connection string)
Skip options: Check to exclude specific object types from the merge.`
}
func parseSkipTablesUI(skipTablesStr string) map[string]bool {
skipTables := make(map[string]bool)
if skipTablesStr == "" {
return skipTables
}
// Split by comma and trim whitespace
parts := strings.Split(skipTablesStr, ",")
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
// Store in lowercase for case-insensitive matching
skipTables[strings.ToLower(trimmed)] = true
}
}
return skipTables
}

View File

@@ -39,6 +39,9 @@ func (se *SchemaEditor) createMainMenu() tview.Primitive {
AddItem("Manage Domains", "View, create, edit, and delete domains", 'd', func() { AddItem("Manage Domains", "View, create, edit, and delete domains", 'd', func() {
se.showDomainList() se.showDomainList()
}). }).
AddItem("Import & Merge", "Import and merge schema from another database", 'i', func() {
se.showImportScreen()
}).
AddItem("Save Database", "Save database to file or database", 'w', func() { AddItem("Save Database", "Save database to file or database", 'w', func() {
se.showSaveScreen() se.showSaveScreen()
}). }).

View File

@@ -24,8 +24,8 @@ func (se *SchemaEditor) showSchemaList() {
schemaTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0) schemaTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0)
// Add header row with padding for full width // Add header row with padding for full width
headers := []string{"Name", "Sequence", "Total Tables", "Total Sequences", "Total Views", "Description"} headers := []string{"Name", "Sequence", "Total Tables", "Total Sequences", "Total Views", "GUID", "Description"}
headerWidths := []int{20, 15, 20, 20, 15} // Last column takes remaining space headerWidths := []int{20, 15, 20, 20, 15, 36} // Last column takes remaining space
for i, header := range headers { for i, header := range headers {
padding := "" padding := ""
if i < len(headerWidths) { if i < len(headerWidths) {
@@ -67,9 +67,14 @@ func (se *SchemaEditor) showSchemaList() {
viewsCell := tview.NewTableCell(viewsStr).SetSelectable(true) viewsCell := tview.NewTableCell(viewsStr).SetSelectable(true)
schemaTable.SetCell(row+1, 4, viewsCell) schemaTable.SetCell(row+1, 4, viewsCell)
// GUID - pad to 36 chars
guidStr := fmt.Sprintf("%-36s", schema.GUID)
guidCell := tview.NewTableCell(guidStr).SetSelectable(true)
schemaTable.SetCell(row+1, 5, guidCell)
// Description - no padding, takes remaining space // Description - no padding, takes remaining space
descCell := tview.NewTableCell(schema.Description).SetSelectable(true) descCell := tview.NewTableCell(schema.Description).SetSelectable(true)
schemaTable.SetCell(row+1, 5, descCell) schemaTable.SetCell(row+1, 6, descCell)
} }
schemaTable.SetTitle(" Schemas ").SetBorder(true).SetTitleAlign(tview.AlignLeft) schemaTable.SetTitle(" Schemas ").SetBorder(true).SetTitleAlign(tview.AlignLeft)
@@ -307,6 +312,7 @@ func (se *SchemaEditor) showEditSchemaDialog(schemaIndex int) {
newName := schema.Name newName := schema.Name
newOwner := schema.Owner newOwner := schema.Owner
newDescription := schema.Description newDescription := schema.Description
newGUID := schema.GUID
form.AddInputField("Schema Name", schema.Name, 40, nil, func(value string) { form.AddInputField("Schema Name", schema.Name, 40, nil, func(value string) {
newName = value newName = value
@@ -320,9 +326,14 @@ func (se *SchemaEditor) showEditSchemaDialog(schemaIndex int) {
newDescription = value newDescription = value
}) })
form.AddInputField("GUID", schema.GUID, 40, nil, func(value string) {
newGUID = value
})
form.AddButton("Save", func() { form.AddButton("Save", func() {
// Apply changes using dataops // Apply changes using dataops
se.UpdateSchema(schemaIndex, newName, newOwner, newDescription) se.UpdateSchema(schemaIndex, newName, newOwner, newDescription)
se.db.Schemas[schemaIndex].GUID = newGUID
schema := se.db.Schemas[schemaIndex] schema := se.db.Schemas[schemaIndex]
se.pages.RemovePage("edit-schema") se.pages.RemovePage("edit-schema")

View File

@@ -24,8 +24,8 @@ func (se *SchemaEditor) showTableList() {
tableTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0) tableTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0)
// Add header row with padding for full width // Add header row with padding for full width
headers := []string{"Name", "Schema", "Sequence", "Total Columns", "Total Relations", "Total Indexes", "Description", "Comment"} headers := []string{"Name", "Schema", "Sequence", "Total Columns", "Total Relations", "Total Indexes", "GUID", "Description", "Comment"}
headerWidths := []int{18, 15, 12, 14, 15, 14, 0, 12} // Description gets remainder headerWidths := []int{18, 15, 12, 14, 15, 14, 36, 0, 12} // Description gets remainder
for i, header := range headers { for i, header := range headers {
padding := "" padding := ""
if i < len(headerWidths) && headerWidths[i] > 0 { if i < len(headerWidths) && headerWidths[i] > 0 {
@@ -82,14 +82,19 @@ func (se *SchemaEditor) showTableList() {
idxCell := tview.NewTableCell(idxStr).SetSelectable(true) idxCell := tview.NewTableCell(idxStr).SetSelectable(true)
tableTable.SetCell(row+1, 5, idxCell) tableTable.SetCell(row+1, 5, idxCell)
// GUID - pad to 36 chars
guidStr := fmt.Sprintf("%-36s", table.GUID)
guidCell := tview.NewTableCell(guidStr).SetSelectable(true)
tableTable.SetCell(row+1, 6, guidCell)
// Description - no padding, takes remaining space // Description - no padding, takes remaining space
descCell := tview.NewTableCell(table.Description).SetSelectable(true) descCell := tview.NewTableCell(table.Description).SetSelectable(true)
tableTable.SetCell(row+1, 6, descCell) tableTable.SetCell(row+1, 7, descCell)
// Comment - pad to 12 chars // Comment - pad to 12 chars
commentStr := fmt.Sprintf("%-12s", table.Comment) commentStr := fmt.Sprintf("%-12s", table.Comment)
commentCell := tview.NewTableCell(commentStr).SetSelectable(true) commentCell := tview.NewTableCell(commentStr).SetSelectable(true)
tableTable.SetCell(row+1, 7, commentCell) tableTable.SetCell(row+1, 8, commentCell)
} }
tableTable.SetTitle(" All Tables ").SetBorder(true).SetTitleAlign(tview.AlignLeft) tableTable.SetTitle(" All Tables ").SetBorder(true).SetTitleAlign(tview.AlignLeft)
@@ -188,8 +193,8 @@ func (se *SchemaEditor) showTableEditor(schemaIndex, tableIndex int, table *mode
colTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0) colTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0)
// Add header row with padding for full width // Add header row with padding for full width
headers := []string{"Name", "Type", "Default", "KeyType", "Description"} headers := []string{"Name", "Type", "Default", "KeyType", "GUID", "Description"}
headerWidths := []int{20, 18, 15, 15} // Last column takes remaining space headerWidths := []int{20, 18, 15, 15, 36} // Last column takes remaining space
for i, header := range headers { for i, header := range headers {
padding := "" padding := ""
if i < len(headerWidths) { if i < len(headerWidths) {
@@ -237,9 +242,14 @@ func (se *SchemaEditor) showTableEditor(schemaIndex, tableIndex int, table *mode
keyTypeCell := tview.NewTableCell(keyTypeStr).SetSelectable(true) keyTypeCell := tview.NewTableCell(keyTypeStr).SetSelectable(true)
colTable.SetCell(row+1, 3, keyTypeCell) colTable.SetCell(row+1, 3, keyTypeCell)
// GUID - pad to 36 chars
guidStr := fmt.Sprintf("%-36s", column.GUID)
guidCell := tview.NewTableCell(guidStr).SetSelectable(true)
colTable.SetCell(row+1, 4, guidCell)
// Description // Description
descCell := tview.NewTableCell(column.Description).SetSelectable(true) descCell := tview.NewTableCell(column.Description).SetSelectable(true)
colTable.SetCell(row+1, 4, descCell) colTable.SetCell(row+1, 5, descCell)
} }
colTable.SetTitle(" Columns ").SetBorder(true).SetTitleAlign(tview.AlignLeft) colTable.SetTitle(" Columns ").SetBorder(true).SetTitleAlign(tview.AlignLeft)
@@ -490,6 +500,7 @@ func (se *SchemaEditor) showEditTableDialog(schemaIndex, tableIndex int) {
// Local variables to collect changes // Local variables to collect changes
newName := table.Name newName := table.Name
newDescription := table.Description newDescription := table.Description
newGUID := table.GUID
form.AddInputField("Table Name", table.Name, 40, nil, func(value string) { form.AddInputField("Table Name", table.Name, 40, nil, func(value string) {
newName = value newName = value
@@ -499,9 +510,14 @@ func (se *SchemaEditor) showEditTableDialog(schemaIndex, tableIndex int) {
newDescription = value newDescription = value
}) })
form.AddInputField("GUID", table.GUID, 40, nil, func(value string) {
newGUID = value
})
form.AddButton("Save", func() { form.AddButton("Save", func() {
// Apply changes using dataops // Apply changes using dataops
se.UpdateTable(schemaIndex, tableIndex, newName, newDescription) se.UpdateTable(schemaIndex, tableIndex, newName, newDescription)
se.db.Schemas[schemaIndex].Tables[tableIndex].GUID = newGUID
table := se.db.Schemas[schemaIndex].Tables[tableIndex] table := se.db.Schemas[schemaIndex].Tables[tableIndex]
se.pages.RemovePage("edit-table") se.pages.RemovePage("edit-table")

View File

@@ -5,6 +5,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
// TemplateData represents the data passed to the template for code generation // TemplateData represents the data passed to the template for code generation
@@ -111,13 +112,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
tableName = schema + "." + table.Name tableName = schema + "." + table.Name
} }
// Generate model name: singularize and convert to PascalCase // Generate model name: Model + Schema + Table (all PascalCase)
singularTable := Singularize(table.Name) singularTable := Singularize(table.Name)
modelName := SnakeCaseToPascalCase(singularTable) tablePart := SnakeCaseToPascalCase(singularTable)
// Add "Model" prefix if not already present // Include schema name in model name
if !hasModelPrefix(modelName) { var modelName string
modelName = "Model" + modelName if schema != "" {
schemaPart := SnakeCaseToPascalCase(schema)
modelName = "Model" + schemaPart + tablePart
} else {
modelName = "Model" + tablePart
} }
model := &ModelData{ model := &ModelData{
@@ -133,8 +138,10 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
// Find primary key // Find primary key
for _, col := range table.Columns { for _, col := range table.Columns {
if col.IsPrimaryKey { if col.IsPrimaryKey {
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name) // Sanitize column name to remove backticks
model.IDColumnName = col.Name safeName := writers.SanitizeStructTagValue(col.Name)
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
model.IDColumnName = safeName
// Check if PK type is a SQL type (contains resolvespec_common or sql_types) // Check if PK type is a SQL type (contains resolvespec_common or sql_types)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types") model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
@@ -146,6 +153,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
columns := sortColumns(table.Columns) columns := sortColumns(table.Columns)
for _, col := range columns { for _, col := range columns {
field := columnToField(col, table, typeMapper) field := columnToField(col, table, typeMapper)
// Check for name collision with generated methods and rename if needed
field.Name = resolveFieldNameCollision(field.Name)
model.Fields = append(model.Fields, field) model.Fields = append(model.Fields, field)
} }
@@ -154,10 +163,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
// columnToField converts a models.Column to FieldData // columnToField converts a models.Column to FieldData
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData { func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
fieldName := SnakeCaseToPascalCase(col.Name) // Sanitize column name first to remove backticks before generating field name
safeName := writers.SanitizeStructTagValue(col.Name)
fieldName := SnakeCaseToPascalCase(safeName)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
bunTag := typeMapper.BuildBunTag(col, table) bunTag := typeMapper.BuildBunTag(col, table)
jsonTag := col.Name // Use column name for JSON tag // Use same sanitized name for JSON tag
jsonTag := safeName
return &FieldData{ return &FieldData{
Name: fieldName, Name: fieldName,
@@ -184,9 +196,28 @@ func formatComment(description, comment string) string {
return comment return comment
} }
// hasModelPrefix checks if a name already has "Model" prefix // resolveFieldNameCollision checks if a field name conflicts with generated method names
func hasModelPrefix(name string) bool { // and adds an underscore suffix if there's a collision
return len(name) >= 5 && name[:5] == "Model" func resolveFieldNameCollision(fieldName string) string {
// List of method names that are generated by the template
reservedNames := map[string]bool{
"TableName": true,
"TableNameOnly": true,
"SchemaName": true,
"GetID": true,
"GetIDStr": true,
"SetID": true,
"UpdateID": true,
"GetIDName": true,
"GetPrefix": true,
}
// Check if field name conflicts with a reserved method name
if reservedNames[fieldName] {
return fieldName + "_"
}
return fieldName
} }
// sortColumns sorts columns by sequence, then by name // sortColumns sorts columns by sequence, then by name

View File

@@ -5,6 +5,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
// TypeMapper handles type conversions between SQL and Go types for Bun // TypeMapper handles type conversions between SQL and Go types for Bun
@@ -164,11 +165,14 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
var parts []string var parts []string
// Column name comes first (no prefix) // Column name comes first (no prefix)
parts = append(parts, column.Name) // Sanitize to remove backticks which would break struct tag syntax
safeName := writers.SanitizeStructTagValue(column.Name)
parts = append(parts, safeName)
// Add type if specified // Add type if specified
if column.Type != "" { if column.Type != "" {
typeStr := column.Type // Sanitize type to remove backticks
typeStr := writers.SanitizeStructTagValue(column.Type)
if column.Length > 0 { if column.Length > 0 {
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length) typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
} else if column.Precision > 0 { } else if column.Precision > 0 {
@@ -188,12 +192,17 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
// Default value // Default value
if column.Default != nil { if column.Default != nil {
parts = append(parts, fmt.Sprintf("default:%v", column.Default)) // Sanitize default value to remove backticks
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
} }
// Nullable (Bun uses nullzero for nullable fields) // Nullable (Bun uses nullzero for nullable fields)
// and notnull tag for explicitly non-nullable fields
if !column.NotNull && !column.IsPrimaryKey { if !column.NotNull && !column.IsPrimaryKey {
parts = append(parts, "nullzero") parts = append(parts, "nullzero")
} else if column.NotNull && !column.IsPrimaryKey {
parts = append(parts, "notnull")
} }
// Check for indexes (unique indexes should be added to tag) // Check for indexes (unique indexes should be added to tag)
@@ -260,7 +269,7 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
// GetSQLTypesImport returns the import path for sql_types (ResolveSpec common) // GetSQLTypesImport returns the import path for sql_types (ResolveSpec common)
func (tm *TypeMapper) GetSQLTypesImport() string { func (tm *TypeMapper) GetSQLTypesImport() string {
return "github.com/bitechdev/ResolveSpec/pkg/common" return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
} }
// GetBunImport returns the import path for Bun // GetBunImport returns the import path for Bun

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"go/format" "go/format"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -124,7 +125,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
} }
// Write output // Write output
return w.writeOutput(formatted) if err := w.writeOutput(formatted); err != nil {
return err
}
// Run go fmt on the output file
if w.options.OutputPath != "" {
w.runGoFmt(w.options.OutputPath)
}
return nil
} }
// writeMultiFile writes each table to a separate file // writeMultiFile writes each table to a separate file
@@ -207,13 +217,19 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
} }
// Generate filename: sql_{schema}_{table}.go // Generate filename: sql_{schema}_{table}.go
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name) // Sanitize schema and table names to remove quotes, comments, and invalid characters
safeSchemaName := writers.SanitizeFilename(schema.Name)
safeTableName := writers.SanitizeFilename(table.Name)
filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName)
filepath := filepath.Join(w.options.OutputPath, filename) filepath := filepath.Join(w.options.OutputPath, filename)
// Write file // Write file
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil { if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
return fmt.Errorf("failed to write file %s: %w", filename, err) return fmt.Errorf("failed to write file %s: %w", filename, err)
} }
// Run go fmt on the generated file
w.runGoFmt(filepath)
} }
} }
@@ -222,6 +238,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
// addRelationshipFields adds relationship fields to the model based on foreign keys // addRelationshipFields adds relationship fields to the model based on foreign keys
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) { func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
// Track used field names to detect duplicates
usedFieldNames := make(map[string]int)
// For each foreign key in this table, add a belongs-to/has-one relationship // For each foreign key in this table, add a belongs-to/has-one relationship
for _, constraint := range table.Constraints { for _, constraint := range table.Constraints {
if constraint.Type != models.ForeignKeyConstraint { if constraint.Type != models.ForeignKeyConstraint {
@@ -235,8 +254,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
} }
// Create relationship field (has-one in Bun, similar to belongs-to in GORM) // Create relationship field (has-one in Bun, similar to belongs-to in GORM)
refModelName := w.getModelName(constraint.ReferencedTable) refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable) fieldName := w.generateHasOneFieldName(constraint)
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one") relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
modelData.AddRelationshipField(&FieldData{ modelData.AddRelationshipField(&FieldData{
@@ -263,8 +283,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
// Check if this constraint references our table // Check if this constraint references our table
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name { if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
// Add has-many relationship // Add has-many relationship
otherModelName := w.getModelName(otherTable.Name) otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many") relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
modelData.AddRelationshipField(&FieldData{ modelData.AddRelationshipField(&FieldData{
@@ -295,22 +316,77 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
return nil return nil
} }
// getModelName generates the model name from a table name // getModelName generates the model name from schema and table name
func (w *Writer) getModelName(tableName string) string { func (w *Writer) getModelName(schemaName, tableName string) string {
singular := Singularize(tableName) singular := Singularize(tableName)
modelName := SnakeCaseToPascalCase(singular) tablePart := SnakeCaseToPascalCase(singular)
if !hasModelPrefix(modelName) { // Include schema name in model name
modelName = "Model" + modelName var modelName string
if schemaName != "" {
schemaPart := SnakeCaseToPascalCase(schemaName)
modelName = "Model" + schemaPart + tablePart
} else {
modelName = "Model" + tablePart
} }
return modelName return modelName
} }
// generateRelationshipFieldName generates a field name for a relationship // generateHasOneFieldName generates a field name for has-one relationships
func (w *Writer) generateRelationshipFieldName(tableName string) string { // Uses the foreign key column name for uniqueness
// Use just the prefix (3 letters) for relationship fields func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
return GeneratePrefix(tableName) // Use the foreign key column name to ensure uniqueness
// If there are multiple columns, use the first one
if len(constraint.Columns) > 0 {
columnName := constraint.Columns[0]
// Convert to PascalCase for proper Go field naming
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
return "Rel" + SnakeCaseToPascalCase(columnName)
}
// Fallback to table-based prefix if no columns defined
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
}
// generateHasManyFieldName generates a field name for has-many relationships
// Uses the foreign key column name + source table name to avoid duplicates
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
// For has-many, we need to include the source table name to avoid duplicates
// e.g., multiple tables referencing the same column on this table
if len(constraint.Columns) > 0 {
columnName := constraint.Columns[0]
// Get the model name for the source table (pluralized)
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
// Remove "Model" prefix if present
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
// Convert column to PascalCase and combine with source table
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
columnPart := SnakeCaseToPascalCase(columnName)
return "Rel" + columnPart + Pluralize(sourceModelName)
}
// Fallback to table-based naming
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
return "Rel" + Pluralize(sourceModelName)
}
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
originalName := fieldName
count := usedNames[originalName]
if count > 0 {
// Name is already used, add numeric suffix
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
}
// Increment the counter for this base name
usedNames[originalName]++
return fieldName
} }
// getPackageName returns the package name from options or defaults to "models" // getPackageName returns the package name from options or defaults to "models"
@@ -341,6 +417,15 @@ func (w *Writer) writeOutput(content string) error {
return nil return nil
} }
// runGoFmt runs go fmt on the specified file
func (w *Writer) runGoFmt(filepath string) {
cmd := exec.Command("gofmt", "-w", filepath)
if err := cmd.Run(); err != nil {
// Don't fail the whole operation if gofmt fails, just warn
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
}
}
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path // shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
func (w *Writer) shouldUseMultiFile() bool { func (w *Writer) shouldUseMultiFile() bool {
// Check if multi_file is explicitly set in metadata // Check if multi_file is explicitly set in metadata
@@ -386,6 +471,7 @@ func (w *Writer) createDatabaseRef(db *models.Database) *models.Database {
DatabaseVersion: db.DatabaseVersion, DatabaseVersion: db.DatabaseVersion,
SourceFormat: db.SourceFormat, SourceFormat: db.SourceFormat,
Schemas: nil, // Don't include schemas to avoid circular reference Schemas: nil, // Don't include schemas to avoid circular reference
GUID: db.GUID,
} }
} }
@@ -402,5 +488,6 @@ func (w *Writer) createSchemaRef(schema *models.Schema, db *models.Database) *mo
Sequence: schema.Sequence, Sequence: schema.Sequence,
RefDatabase: w.createDatabaseRef(db), // Include database ref RefDatabase: w.createDatabaseRef(db), // Include database ref
Tables: nil, // Don't include tables to avoid circular reference Tables: nil, // Don't include tables to avoid circular reference
GUID: schema.GUID,
} }
} }

View File

@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
// Verify key elements are present // Verify key elements are present
expectations := []string{ expectations := []string{
"package models", "package models",
"type ModelUser struct", "type ModelPublicUser struct",
"bun.BaseModel", "bun.BaseModel",
"table:public.users", "table:public.users",
"alias:users", "alias:users",
@@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) {
"resolvespec_common.SqlTime", "resolvespec_common.SqlTime",
"bun:\"id", "bun:\"id",
"bun:\"email", "bun:\"email",
"func (m ModelUser) TableName() string", "func (m ModelPublicUser) TableName() string",
"return \"public.users\"", "return \"public.users\"",
"func (m ModelUser) GetID() int64", "func (m ModelPublicUser) GetID() int64",
} }
for _, expected := range expectations { for _, expected := range expectations {
@@ -175,12 +175,378 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
postsStr := string(postsContent) postsStr := string(postsContent)
// Verify relationship is present with Bun format // Verify relationship is present with Bun format
if !strings.Contains(postsStr, "USE") { // Should now be RelUserID (has-one) instead of USE
t.Errorf("Missing relationship field USE") if !strings.Contains(postsStr, "RelUserID") {
t.Errorf("Missing relationship field RelUserID (new naming convention)")
} }
if !strings.Contains(postsStr, "rel:has-one") { if !strings.Contains(postsStr, "rel:has-one") {
t.Errorf("Missing Bun relationship tag: %s", postsStr) t.Errorf("Missing Bun relationship tag: %s", postsStr)
} }
// Check users file contains has-many relationship
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
if err != nil {
t.Fatalf("Failed to read users file: %v", err)
}
usersStr := string(usersContent)
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
}
}
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
// Test scenario: api_event table with multiple foreign keys to filepointer table
db := models.InitDatabase("testdb")
schema := models.InitSchema("org")
// Filepointer table
filepointer := models.InitTable("filepointer", "org")
filepointer.Columns["id_filepointer"] = &models.Column{
Name: "id_filepointer",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
schema.Tables = append(schema.Tables, filepointer)
// API event table with two foreign keys to filepointer
apiEvent := models.InitTable("api_event", "org")
apiEvent.Columns["id_api_event"] = &models.Column{
Name: "id_api_event",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
Name: "rid_filepointer_request",
Type: "bigint",
NotNull: false,
}
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
Name: "rid_filepointer_response",
Type: "bigint",
NotNull: false,
}
// Add constraints
apiEvent.Constraints["fk_request"] = &models.Constraint{
Name: "fk_request",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_filepointer_request"},
ReferencedTable: "filepointer",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_filepointer"},
}
apiEvent.Constraints["fk_response"] = &models.Constraint{
Name: "fk_response",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_filepointer_response"},
ReferencedTable: "filepointer",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_filepointer"},
}
schema.Tables = append(schema.Tables, apiEvent)
db.Schemas = append(db.Schemas, schema)
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: tmpDir,
Metadata: map[string]interface{}{
"multi_file": true,
},
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
// Read the api_event file
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
if err != nil {
t.Fatalf("Failed to read api_event file: %v", err)
}
contentStr := string(apiEventContent)
// Verify both relationships have unique names based on column names
expectations := []struct {
fieldName string
tag string
}{
{"RelRIDFilepointerRequest", "join:rid_filepointer_request=id_filepointer"},
{"RelRIDFilepointerResponse", "join:rid_filepointer_response=id_filepointer"},
}
for _, exp := range expectations {
if !strings.Contains(contentStr, exp.fieldName) {
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
}
if !strings.Contains(contentStr, exp.tag) {
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
}
}
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
}
// Also verify has-many relationships on filepointer table
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
if err != nil {
t.Fatalf("Failed to read filepointer file: %v", err)
}
filepointerStr := string(filepointerContent)
// Should have two different has-many relationships with unique names
hasManyExpectations := []string{
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
}
for _, exp := range hasManyExpectations {
if !strings.Contains(filepointerStr, exp) {
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
}
}
}
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
db := models.InitDatabase("testdb")
schema := models.InitSchema("org")
// Owner table
owner := models.InitTable("owner", "org")
owner.Columns["id_owner"] = &models.Column{
Name: "id_owner",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
schema.Tables = append(schema.Tables, owner)
// API Provider table
apiProvider := models.InitTable("api_provider", "org")
apiProvider.Columns["id_api_provider"] = &models.Column{
Name: "id_api_provider",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiProvider.Columns["rid_owner"] = &models.Column{
Name: "rid_owner",
Type: "bigint",
NotNull: true,
}
apiProvider.Constraints["fk_owner"] = &models.Constraint{
Name: "fk_owner",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_owner"},
ReferencedTable: "owner",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_owner"},
}
schema.Tables = append(schema.Tables, apiProvider)
// Login table
login := models.InitTable("login", "org")
login.Columns["id_login"] = &models.Column{
Name: "id_login",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
login.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
login.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, login)
// Filepointer table
filepointer := models.InitTable("filepointer", "org")
filepointer.Columns["id_filepointer"] = &models.Column{
Name: "id_filepointer",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
filepointer.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, filepointer)
// API Event table
apiEvent := models.InitTable("api_event", "org")
apiEvent.Columns["id_api_event"] = &models.Column{
Name: "id_api_event",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiEvent.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, apiEvent)
db.Schemas = append(db.Schemas, schema)
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: tmpDir,
Metadata: map[string]interface{}{
"multi_file": true,
},
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
// Read the api_provider file
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
if err != nil {
t.Fatalf("Failed to read api_provider file: %v", err)
}
contentStr := string(apiProviderContent)
// Verify all has-many relationships have unique names
hasManyExpectations := []string{
"RelRIDAPIProviderOrgLogins", // Has many via Login
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
"RelRIDOwner", // Has one via rid_owner
}
for _, exp := range hasManyExpectations {
if !strings.Contains(contentStr, exp) {
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
}
}
// Verify NO duplicate field names
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
count := strings.Count(contentStr, "RelRIDAPIProvider")
if count != 3 {
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
}
// Verify no duplicate declarations (would cause compilation error)
duplicatePattern := "RelRIDAPIProviders []*Model"
if strings.Contains(contentStr, duplicatePattern) {
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
}
}
func TestWriter_FieldNameCollision(t *testing.T) {
// Test scenario: table with columns that would conflict with generated method names
table := models.InitTable("audit_table", "audit")
table.Columns["id_audit_table"] = &models.Column{
Name: "id_audit_table",
Type: "smallint",
NotNull: true,
IsPrimaryKey: true,
Sequence: 1,
}
table.Columns["table_name"] = &models.Column{
Name: "table_name",
Type: "varchar",
Length: 100,
NotNull: true,
Sequence: 2,
}
table.Columns["table_schema"] = &models.Column{
Name: "table_schema",
Type: "varchar",
Length: 100,
NotNull: true,
Sequence: 3,
}
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
// Read the generated file
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
// Verify that TableName field was renamed to TableName_ to avoid collision
if !strings.Contains(generated, "TableName_") {
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
}
// Verify the struct tag still references the correct database column
if !strings.Contains(generated, `bun:"table_name,`) {
t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated)
}
// Verify the TableName() method still exists and doesn't conflict
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
}
// Verify NO field named just "TableName" (without underscore)
if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") {
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
}
} }
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) { func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {

View File

@@ -126,7 +126,15 @@ func (w *Writer) tableToDBML(t *models.Table) string {
attrs = append(attrs, "increment") attrs = append(attrs, "increment")
} }
if column.Default != nil { if column.Default != nil {
attrs = append(attrs, fmt.Sprintf("default: `%v`", column.Default)) // Check if default value contains backticks (DBML expressions like `now()`)
defaultStr := fmt.Sprintf("%v", column.Default)
if strings.HasPrefix(defaultStr, "`") && strings.HasSuffix(defaultStr, "`") {
// Already an expression with backticks, use as-is
attrs = append(attrs, fmt.Sprintf("default: %s", defaultStr))
} else {
// Regular value, wrap in single quotes
attrs = append(attrs, fmt.Sprintf("default: '%v'", column.Default))
}
} }
if len(attrs) > 0 { if len(attrs) > 0 {

View File

@@ -133,7 +133,11 @@ func (w *Writer) mapTableFields(table *models.Table) models.DCTXTable {
prefix = table.Name[:3] prefix = table.Name[:3]
} }
tableGuid := w.newGUID() // Use GUID from model if available, otherwise generate a new one
tableGuid := table.GUID
if tableGuid == "" {
tableGuid = w.newGUID()
}
w.tableGuidMap[table.Name] = tableGuid w.tableGuidMap[table.Name] = tableGuid
dctxTable := models.DCTXTable{ dctxTable := models.DCTXTable{
@@ -171,7 +175,11 @@ func (w *Writer) mapTableKeys(table *models.Table) []models.DCTXKey {
} }
func (w *Writer) mapField(column *models.Column) models.DCTXField { func (w *Writer) mapField(column *models.Column) models.DCTXField {
guid := w.newGUID() // Use GUID from model if available, otherwise generate a new one
guid := column.GUID
if guid == "" {
guid = w.newGUID()
}
fieldKey := fmt.Sprintf("%s.%s", column.Table, column.Name) fieldKey := fmt.Sprintf("%s.%s", column.Table, column.Name)
w.fieldGuidMap[fieldKey] = guid w.fieldGuidMap[fieldKey] = guid
@@ -209,7 +217,11 @@ func (w *Writer) mapDataType(dataType string) string {
} }
func (w *Writer) mapKey(index *models.Index, table *models.Table) models.DCTXKey { func (w *Writer) mapKey(index *models.Index, table *models.Table) models.DCTXKey {
guid := w.newGUID() // Use GUID from model if available, otherwise generate a new one
guid := index.GUID
if guid == "" {
guid = w.newGUID()
}
keyKey := fmt.Sprintf("%s.%s", table.Name, index.Name) keyKey := fmt.Sprintf("%s.%s", table.Name, index.Name)
w.keyGuidMap[keyKey] = guid w.keyGuidMap[keyKey] = guid
@@ -344,7 +356,7 @@ func (w *Writer) mapRelation(rel *models.Relationship, schema *models.Schema) mo
} }
return models.DCTXRelation{ return models.DCTXRelation{
Guid: w.newGUID(), Guid: rel.GUID, // Use GUID from relationship model
PrimaryTable: w.tableGuidMap[rel.ToTable], // GUID of the 'to' table (e.g., users) PrimaryTable: w.tableGuidMap[rel.ToTable], // GUID of the 'to' table (e.g., users)
ForeignTable: w.tableGuidMap[rel.FromTable], // GUID of the 'from' table (e.g., posts) ForeignTable: w.tableGuidMap[rel.FromTable], // GUID of the 'from' table (e.g., posts)
PrimaryKey: primaryKeyGUID, PrimaryKey: primaryKeyGUID,

View File

@@ -196,7 +196,9 @@ func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *
} }
// Generate filename: {tableName}.ts // Generate filename: {tableName}.ts
filename := filepath.Join(w.options.OutputPath, table.Name+".ts") // Sanitize table name to remove quotes, comments, and invalid characters
safeTableName := writers.SanitizeFilename(table.Name)
filename := filepath.Join(w.options.OutputPath, safeTableName+".ts")
return os.WriteFile(filename, []byte(code), 0644) return os.WriteFile(filename, []byte(code), 0644)
} }

View File

@@ -4,6 +4,7 @@ import (
"sort" "sort"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
// TemplateData represents the data passed to the template for code generation // TemplateData represents the data passed to the template for code generation
@@ -24,6 +25,7 @@ type ModelData struct {
Fields []*FieldData Fields []*FieldData
Config *MethodConfig Config *MethodConfig
PrimaryKeyField string // Name of the primary key field PrimaryKeyField string // Name of the primary key field
PrimaryKeyType string // Go type of the primary key field
IDColumnName string // Name of the ID column in database IDColumnName string // Name of the ID column in database
Prefix string // 3-letter prefix Prefix string // 3-letter prefix
} }
@@ -109,13 +111,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
tableName = schema + "." + table.Name tableName = schema + "." + table.Name
} }
// Generate model name: singularize and convert to PascalCase // Generate model name: Model + Schema + Table (all PascalCase)
singularTable := Singularize(table.Name) singularTable := Singularize(table.Name)
modelName := SnakeCaseToPascalCase(singularTable) tablePart := SnakeCaseToPascalCase(singularTable)
// Add "Model" prefix if not already present // Include schema name in model name
if !hasModelPrefix(modelName) { var modelName string
modelName = "Model" + modelName if schema != "" {
schemaPart := SnakeCaseToPascalCase(schema)
modelName = "Model" + schemaPart + tablePart
} else {
modelName = "Model" + tablePart
} }
model := &ModelData{ model := &ModelData{
@@ -131,8 +137,11 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
// Find primary key // Find primary key
for _, col := range table.Columns { for _, col := range table.Columns {
if col.IsPrimaryKey { if col.IsPrimaryKey {
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name) // Sanitize column name to remove backticks
model.IDColumnName = col.Name safeName := writers.SanitizeStructTagValue(col.Name)
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
model.IDColumnName = safeName
break break
} }
} }
@@ -141,6 +150,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
columns := sortColumns(table.Columns) columns := sortColumns(table.Columns)
for _, col := range columns { for _, col := range columns {
field := columnToField(col, table, typeMapper) field := columnToField(col, table, typeMapper)
// Check for name collision with generated methods and rename if needed
field.Name = resolveFieldNameCollision(field.Name)
model.Fields = append(model.Fields, field) model.Fields = append(model.Fields, field)
} }
@@ -149,10 +160,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
// columnToField converts a models.Column to FieldData // columnToField converts a models.Column to FieldData
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData { func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
fieldName := SnakeCaseToPascalCase(col.Name) // Sanitize column name first to remove backticks before generating field name
safeName := writers.SanitizeStructTagValue(col.Name)
fieldName := SnakeCaseToPascalCase(safeName)
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull) goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
gormTag := typeMapper.BuildGormTag(col, table) gormTag := typeMapper.BuildGormTag(col, table)
jsonTag := col.Name // Use column name for JSON tag // Use same sanitized name for JSON tag
jsonTag := safeName
return &FieldData{ return &FieldData{
Name: fieldName, Name: fieldName,
@@ -179,9 +193,28 @@ func formatComment(description, comment string) string {
return comment return comment
} }
// hasModelPrefix checks if a name already has "Model" prefix // resolveFieldNameCollision checks if a field name conflicts with generated method names
func hasModelPrefix(name string) bool { // and adds an underscore suffix if there's a collision
return len(name) >= 5 && name[:5] == "Model" func resolveFieldNameCollision(fieldName string) string {
// List of method names that are generated by the template
reservedNames := map[string]bool{
"TableName": true,
"TableNameOnly": true,
"SchemaName": true,
"GetID": true,
"GetIDStr": true,
"SetID": true,
"UpdateID": true,
"GetIDName": true,
"GetPrefix": true,
}
// Check if field name conflicts with a reserved method name
if reservedNames[fieldName] {
return fieldName + "_"
}
return fieldName
} }
// sortColumns sorts columns by sequence, then by name // sortColumns sorts columns by sequence, then by name

View File

@@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) {
{{if and .Config.GenerateUpdateID .PrimaryKeyField}} {{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value // UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) { func (m *{{.Name}}) UpdateID(newid int64) {
m.{{.PrimaryKeyField}} = int32(newid) m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
} }
{{end}} {{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}} {{if and .Config.GenerateGetIDName .IDColumnName}}

View File

@@ -5,6 +5,7 @@ import (
"strings" "strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
) )
// TypeMapper handles type conversions between SQL and Go types // TypeMapper handles type conversions between SQL and Go types
@@ -199,12 +200,15 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
var parts []string var parts []string
// Always include column name (lowercase as per user requirement) // Always include column name (lowercase as per user requirement)
parts = append(parts, fmt.Sprintf("column:%s", column.Name)) // Sanitize to remove backticks which would break struct tag syntax
safeName := writers.SanitizeStructTagValue(column.Name)
parts = append(parts, fmt.Sprintf("column:%s", safeName))
// Add type if specified // Add type if specified
if column.Type != "" { if column.Type != "" {
// Include length, precision, scale if present // Include length, precision, scale if present
typeStr := column.Type // Sanitize type to remove backticks
typeStr := writers.SanitizeStructTagValue(column.Type)
if column.Length > 0 { if column.Length > 0 {
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length) typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
} else if column.Precision > 0 { } else if column.Precision > 0 {
@@ -234,7 +238,9 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
// Default value // Default value
if column.Default != nil { if column.Default != nil {
parts = append(parts, fmt.Sprintf("default:%v", column.Default)) // Sanitize default value to remove backticks
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
} }
// Check for unique constraint // Check for unique constraint
@@ -331,5 +337,5 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
// GetSQLTypesImport returns the import path for sql_types // GetSQLTypesImport returns the import path for sql_types
func (tm *TypeMapper) GetSQLTypesImport() string { func (tm *TypeMapper) GetSQLTypesImport() string {
return "github.com/bitechdev/ResolveSpec/pkg/common/sql_types" return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
} }

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"go/format" "go/format"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -121,7 +122,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
} }
// Write output // Write output
return w.writeOutput(formatted) if err := w.writeOutput(formatted); err != nil {
return err
}
// Run go fmt on the output file
if w.options.OutputPath != "" {
w.runGoFmt(w.options.OutputPath)
}
return nil
} }
// writeMultiFile writes each table to a separate file // writeMultiFile writes each table to a separate file
@@ -201,13 +211,19 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
} }
// Generate filename: sql_{schema}_{table}.go // Generate filename: sql_{schema}_{table}.go
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name) // Sanitize schema and table names to remove quotes, comments, and invalid characters
safeSchemaName := writers.SanitizeFilename(schema.Name)
safeTableName := writers.SanitizeFilename(table.Name)
filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName)
filepath := filepath.Join(w.options.OutputPath, filename) filepath := filepath.Join(w.options.OutputPath, filename)
// Write file // Write file
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil { if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
return fmt.Errorf("failed to write file %s: %w", filename, err) return fmt.Errorf("failed to write file %s: %w", filename, err)
} }
// Run go fmt on the generated file
w.runGoFmt(filepath)
} }
} }
@@ -216,6 +232,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
// addRelationshipFields adds relationship fields to the model based on foreign keys // addRelationshipFields adds relationship fields to the model based on foreign keys
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) { func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
// Track used field names to detect duplicates
usedFieldNames := make(map[string]int)
// For each foreign key in this table, add a belongs-to relationship // For each foreign key in this table, add a belongs-to relationship
for _, constraint := range table.Constraints { for _, constraint := range table.Constraints {
if constraint.Type != models.ForeignKeyConstraint { if constraint.Type != models.ForeignKeyConstraint {
@@ -229,8 +248,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
} }
// Create relationship field (belongs-to) // Create relationship field (belongs-to)
refModelName := w.getModelName(constraint.ReferencedTable) refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable) fieldName := w.generateBelongsToFieldName(constraint)
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false) relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
modelData.AddRelationshipField(&FieldData{ modelData.AddRelationshipField(&FieldData{
@@ -257,8 +277,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
// Check if this constraint references our table // Check if this constraint references our table
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name { if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
// Add has-many relationship // Add has-many relationship
otherModelName := w.getModelName(otherTable.Name) otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true) relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
modelData.AddRelationshipField(&FieldData{ modelData.AddRelationshipField(&FieldData{
@@ -289,22 +310,77 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
return nil return nil
} }
// getModelName generates the model name from a table name // getModelName generates the model name from schema and table name
func (w *Writer) getModelName(tableName string) string { func (w *Writer) getModelName(schemaName, tableName string) string {
singular := Singularize(tableName) singular := Singularize(tableName)
modelName := SnakeCaseToPascalCase(singular) tablePart := SnakeCaseToPascalCase(singular)
if !hasModelPrefix(modelName) { // Include schema name in model name
modelName = "Model" + modelName var modelName string
if schemaName != "" {
schemaPart := SnakeCaseToPascalCase(schemaName)
modelName = "Model" + schemaPart + tablePart
} else {
modelName = "Model" + tablePart
} }
return modelName return modelName
} }
// generateRelationshipFieldName generates a field name for a relationship // generateBelongsToFieldName generates a field name for belongs-to relationships
func (w *Writer) generateRelationshipFieldName(tableName string) string { // Uses the foreign key column name for uniqueness
// Use just the prefix (3 letters) for relationship fields func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) string {
return GeneratePrefix(tableName) // Use the foreign key column name to ensure uniqueness
// If there are multiple columns, use the first one
if len(constraint.Columns) > 0 {
columnName := constraint.Columns[0]
// Convert to PascalCase for proper Go field naming
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
return "Rel" + SnakeCaseToPascalCase(columnName)
}
// Fallback to table-based prefix if no columns defined
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
}
// generateHasManyFieldName generates a field name for has-many relationships
// Uses the foreign key column name + source table name to avoid duplicates
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
// For has-many, we need to include the source table name to avoid duplicates
// e.g., multiple tables referencing the same column on this table
if len(constraint.Columns) > 0 {
columnName := constraint.Columns[0]
// Get the model name for the source table (pluralized)
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
// Remove "Model" prefix if present
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
// Convert column to PascalCase and combine with source table
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
columnPart := SnakeCaseToPascalCase(columnName)
return "Rel" + columnPart + Pluralize(sourceModelName)
}
// Fallback to table-based naming
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
return "Rel" + Pluralize(sourceModelName)
}
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
originalName := fieldName
count := usedNames[originalName]
if count > 0 {
// Name is already used, add numeric suffix
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
}
// Increment the counter for this base name
usedNames[originalName]++
return fieldName
} }
// getPackageName returns the package name from options or defaults to "models" // getPackageName returns the package name from options or defaults to "models"
@@ -335,6 +411,15 @@ func (w *Writer) writeOutput(content string) error {
return nil return nil
} }
// runGoFmt runs go fmt on the specified file
func (w *Writer) runGoFmt(filepath string) {
cmd := exec.Command("gofmt", "-w", filepath)
if err := cmd.Run(); err != nil {
// Don't fail the whole operation if gofmt fails, just warn
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
}
}
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path // shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
func (w *Writer) shouldUseMultiFile() bool { func (w *Writer) shouldUseMultiFile() bool {
// Check if multi_file is explicitly set in metadata // Check if multi_file is explicitly set in metadata
@@ -380,6 +465,7 @@ func (w *Writer) createDatabaseRef(db *models.Database) *models.Database {
DatabaseVersion: db.DatabaseVersion, DatabaseVersion: db.DatabaseVersion,
SourceFormat: db.SourceFormat, SourceFormat: db.SourceFormat,
Schemas: nil, // Don't include schemas to avoid circular reference Schemas: nil, // Don't include schemas to avoid circular reference
GUID: db.GUID,
} }
} }
@@ -396,5 +482,6 @@ func (w *Writer) createSchemaRef(schema *models.Schema, db *models.Database) *mo
Sequence: schema.Sequence, Sequence: schema.Sequence,
RefDatabase: w.createDatabaseRef(db), // Include database ref RefDatabase: w.createDatabaseRef(db), // Include database ref
Tables: nil, // Don't include tables to avoid circular reference Tables: nil, // Don't include tables to avoid circular reference
GUID: schema.GUID,
} }
} }

View File

@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
// Verify key elements are present // Verify key elements are present
expectations := []string{ expectations := []string{
"package models", "package models",
"type ModelUser struct", "type ModelPublicUser struct",
"ID", "ID",
"int64", "int64",
"Email", "Email",
@@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) {
"time.Time", "time.Time",
"gorm:\"column:id", "gorm:\"column:id",
"gorm:\"column:email", "gorm:\"column:email",
"func (m ModelUser) TableName() string", "func (m ModelPublicUser) TableName() string",
"return \"public.users\"", "return \"public.users\"",
"func (m ModelUser) GetID() int64", "func (m ModelPublicUser) GetID() int64",
} }
for _, expected := range expectations { for _, expected := range expectations {
@@ -164,9 +164,437 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
t.Fatalf("Failed to read posts file: %v", err) t.Fatalf("Failed to read posts file: %v", err)
} }
if !strings.Contains(string(postsContent), "USE *ModelUser") { postsStr := string(postsContent)
// Relationship field should be present
t.Logf("Posts content:\n%s", string(postsContent)) // Verify relationship is present with new naming convention
// Should now be RelUserID (belongs-to) instead of USE
if !strings.Contains(postsStr, "RelUserID") {
t.Errorf("Missing relationship field RelUserID (new naming convention)")
}
// Check users file contains has-many relationship
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
if err != nil {
t.Fatalf("Failed to read users file: %v", err)
}
usersStr := string(usersContent)
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
}
}
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
// Test scenario: api_event table with multiple foreign keys to filepointer table
db := models.InitDatabase("testdb")
schema := models.InitSchema("org")
// Filepointer table
filepointer := models.InitTable("filepointer", "org")
filepointer.Columns["id_filepointer"] = &models.Column{
Name: "id_filepointer",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
schema.Tables = append(schema.Tables, filepointer)
// API event table with two foreign keys to filepointer
apiEvent := models.InitTable("api_event", "org")
apiEvent.Columns["id_api_event"] = &models.Column{
Name: "id_api_event",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
Name: "rid_filepointer_request",
Type: "bigint",
NotNull: false,
}
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
Name: "rid_filepointer_response",
Type: "bigint",
NotNull: false,
}
// Add constraints
apiEvent.Constraints["fk_request"] = &models.Constraint{
Name: "fk_request",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_filepointer_request"},
ReferencedTable: "filepointer",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_filepointer"},
}
apiEvent.Constraints["fk_response"] = &models.Constraint{
Name: "fk_response",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_filepointer_response"},
ReferencedTable: "filepointer",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_filepointer"},
}
schema.Tables = append(schema.Tables, apiEvent)
db.Schemas = append(db.Schemas, schema)
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: tmpDir,
Metadata: map[string]interface{}{
"multi_file": true,
},
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
// Read the api_event file
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
if err != nil {
t.Fatalf("Failed to read api_event file: %v", err)
}
contentStr := string(apiEventContent)
// Verify both relationships have unique names based on column names
expectations := []struct {
fieldName string
tag string
}{
{"RelRIDFilepointerRequest", "foreignKey:RIDFilepointerRequest"},
{"RelRIDFilepointerResponse", "foreignKey:RIDFilepointerResponse"},
}
for _, exp := range expectations {
if !strings.Contains(contentStr, exp.fieldName) {
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
}
if !strings.Contains(contentStr, exp.tag) {
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
}
}
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
}
// Also verify has-many relationships on filepointer table
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
if err != nil {
t.Fatalf("Failed to read filepointer file: %v", err)
}
filepointerStr := string(filepointerContent)
// Should have two different has-many relationships with unique names
hasManyExpectations := []string{
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
}
for _, exp := range hasManyExpectations {
if !strings.Contains(filepointerStr, exp) {
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
}
}
}
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
db := models.InitDatabase("testdb")
schema := models.InitSchema("org")
// Owner table
owner := models.InitTable("owner", "org")
owner.Columns["id_owner"] = &models.Column{
Name: "id_owner",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
schema.Tables = append(schema.Tables, owner)
// API Provider table
apiProvider := models.InitTable("api_provider", "org")
apiProvider.Columns["id_api_provider"] = &models.Column{
Name: "id_api_provider",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiProvider.Columns["rid_owner"] = &models.Column{
Name: "rid_owner",
Type: "bigint",
NotNull: true,
}
apiProvider.Constraints["fk_owner"] = &models.Constraint{
Name: "fk_owner",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_owner"},
ReferencedTable: "owner",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_owner"},
}
schema.Tables = append(schema.Tables, apiProvider)
// Login table
login := models.InitTable("login", "org")
login.Columns["id_login"] = &models.Column{
Name: "id_login",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
login.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
login.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, login)
// Filepointer table
filepointer := models.InitTable("filepointer", "org")
filepointer.Columns["id_filepointer"] = &models.Column{
Name: "id_filepointer",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
filepointer.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, filepointer)
// API Event table
apiEvent := models.InitTable("api_event", "org")
apiEvent.Columns["id_api_event"] = &models.Column{
Name: "id_api_event",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
}
apiEvent.Columns["rid_api_provider"] = &models.Column{
Name: "rid_api_provider",
Type: "bigint",
NotNull: true,
}
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
Name: "fk_api_provider",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_api_provider"},
ReferencedTable: "api_provider",
ReferencedSchema: "org",
ReferencedColumns: []string{"id_api_provider"},
}
schema.Tables = append(schema.Tables, apiEvent)
db.Schemas = append(db.Schemas, schema)
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: tmpDir,
Metadata: map[string]interface{}{
"multi_file": true,
},
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
// Read the api_provider file
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
if err != nil {
t.Fatalf("Failed to read api_provider file: %v", err)
}
contentStr := string(apiProviderContent)
// Verify all has-many relationships have unique names
hasManyExpectations := []string{
"RelRIDAPIProviderOrgLogins", // Has many via Login
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
"RelRIDOwner", // Belongs to via rid_owner
}
for _, exp := range hasManyExpectations {
if !strings.Contains(contentStr, exp) {
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
}
}
// Verify NO duplicate field names
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
count := strings.Count(contentStr, "RelRIDAPIProvider")
if count != 3 {
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
}
// Verify no duplicate declarations (would cause compilation error)
duplicatePattern := "RelRIDAPIProviders []*Model"
if strings.Contains(contentStr, duplicatePattern) {
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
}
}
func TestWriter_FieldNameCollision(t *testing.T) {
// Test scenario: table with columns that would conflict with generated method names
table := models.InitTable("audit_table", "audit")
table.Columns["id_audit_table"] = &models.Column{
Name: "id_audit_table",
Type: "smallint",
NotNull: true,
IsPrimaryKey: true,
Sequence: 1,
}
table.Columns["table_name"] = &models.Column{
Name: "table_name",
Type: "varchar",
Length: 100,
NotNull: true,
Sequence: 2,
}
table.Columns["table_schema"] = &models.Column{
Name: "table_schema",
Type: "varchar",
Length: 100,
NotNull: true,
Sequence: 3,
}
// Create writer
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
// Read the generated file
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
// Verify that TableName field was renamed to TableName_ to avoid collision
if !strings.Contains(generated, "TableName_") {
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
}
// Verify the struct tag still references the correct database column
if !strings.Contains(generated, `gorm:"column:table_name;`) {
t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated)
}
// Verify the TableName() method still exists and doesn't conflict
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
}
// Verify NO field named just "TableName" (without underscore)
if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") {
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
}
}
func TestWriter_UpdateIDTypeSafety(t *testing.T) {
// Test scenario: tables with different primary key types
tests := []struct {
name string
pkType string
expectedPK string
castType string
}{
{"int32_pk", "int", "int32", "int32(newid)"},
{"int16_pk", "smallint", "int16", "int16(newid)"},
{"int64_pk", "bigint", "int64", "int64(newid)"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
table := models.InitTable("test_table", "public")
table.Columns["id"] = &models.Column{
Name: "id",
Type: tt.pkType,
NotNull: true,
IsPrimaryKey: true,
}
tmpDir := t.TempDir()
opts := &writers.WriterOptions{
PackageName: "models",
OutputPath: filepath.Join(tmpDir, "test.go"),
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable failed: %v", err)
}
content, err := os.ReadFile(opts.OutputPath)
if err != nil {
t.Fatalf("Failed to read generated file: %v", err)
}
generated := string(content)
// Verify UpdateID method has correct type cast
if !strings.Contains(generated, tt.castType) {
t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated)
}
// Verify no invalid int32(newid) for non-int32 types
if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") {
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
}
// Verify UpdateID parameter is int64 (for consistency)
if !strings.Contains(generated, "UpdateID(newid int64)") {
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
}
})
} }
} }

View File

@@ -1,6 +1,9 @@
package writers package writers
import ( import (
"regexp"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/models"
) )
@@ -28,3 +31,56 @@ type WriterOptions struct {
// Additional options can be added here as needed // Additional options can be added here as needed
Metadata map[string]interface{} Metadata map[string]interface{}
} }
// SanitizeFilename removes quotes, comments, and invalid characters from identifiers
// to make them safe for use in filenames. This handles:
// - Double and single quotes: "table_name" or 'table_name' -> table_name
// - DBML comments: table [note: 'description'] -> table
// - Invalid filename characters: replaced with underscores
func SanitizeFilename(name string) string {
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
name = commentRegex.ReplaceAllString(name, "")
// Remove quotes (both single and double)
name = strings.ReplaceAll(name, `"`, "")
name = strings.ReplaceAll(name, `'`, "")
// Remove backticks (MySQL style identifiers)
name = strings.ReplaceAll(name, "`", "")
// Replace invalid filename characters with underscores
// Invalid chars: / \ : * ? " < > | and control characters
invalidChars := regexp.MustCompile(`[/\\:*?"<>|\x00-\x1f\x7f]`)
name = invalidChars.ReplaceAllString(name, "_")
// Trim whitespace and consecutive underscores
name = strings.TrimSpace(name)
name = regexp.MustCompile(`_+`).ReplaceAllString(name, "_")
name = strings.Trim(name, "_")
return name
}
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
// Go struct tags are delimited by backticks, so any backtick in the value would break the syntax.
// This function:
// - Removes DBML/DCTX comments in brackets
// - Removes all quotes (double, single, and backticks)
// - Returns a clean identifier safe for use in struct tags and field names
func SanitizeStructTagValue(value string) string {
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
value = commentRegex.ReplaceAllString(value, "")
// Trim whitespace
value = strings.TrimSpace(value)
// Remove all quotes: backticks, double quotes, and single quotes
// This ensures the value is clean for use as Go identifiers and struct tag values
value = strings.ReplaceAll(value, "`", "")
value = strings.ReplaceAll(value, `"`, "")
value = strings.ReplaceAll(value, `'`, "")
return value
}

View File

@@ -0,0 +1,5 @@
// First file - users table basic structure
Table public.users {
id bigint [pk, increment]
email varchar(255) [unique, not null]
}

View File

@@ -0,0 +1,8 @@
// Second file - posts table
Table public.posts {
id bigint [pk, increment]
user_id bigint [not null]
title varchar(200) [not null]
content text
created_at timestamp [not null]
}

View File

@@ -0,0 +1,5 @@
// Third file - adds more columns to users table (tests merging)
Table public.users {
name varchar(100)
created_at timestamp [not null]
}

View File

@@ -0,0 +1,10 @@
// File with commented-out refs - should load last
// Contains relationships that depend on earlier tables
// Ref: public.posts.user_id > public.users.id [ondelete: CASCADE]
Table public.comments {
id bigint [pk, increment]
post_id bigint [not null]
content text [not null]
}