Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 77436757c8 | |||
| 5e6f03e412 | |||
| 1dcbc79387 | |||
| 59c4a5ebf8 | |||
| 091e1913ee | |||
| 0e6e94797c | |||
| a033349c76 | |||
| 466d657ea7 | |||
| 47bf748fd5 | |||
| 88589e00e7 | |||
| 4cdccde9cf | |||
| aba22cb574 | |||
| d0630b4899 | |||
| c9eed9b794 | |||
|
|
5fb09b78c3 | ||
| 5d9770b430 | |||
| f2d500f98d | |||
| 2ec9991324 |
11
.github/workflows/release.yml
vendored
11
.github/workflows/release.yml
vendored
@@ -25,6 +25,7 @@ jobs:
|
||||
id: get_version
|
||||
run: |
|
||||
echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
||||
echo "BUILD_DATE=$(date -u '+%Y-%m-%d %H:%M:%S UTC')" >> $GITHUB_OUTPUT
|
||||
echo "Version: ${GITHUB_REF#refs/tags/}"
|
||||
|
||||
- name: Build binaries for multiple platforms
|
||||
@@ -32,19 +33,19 @@ jobs:
|
||||
mkdir -p dist
|
||||
|
||||
# Linux AMD64
|
||||
GOOS=linux GOARCH=amd64 go build -o dist/relspec-linux-amd64 -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec
|
||||
GOOS=linux GOARCH=amd64 go build -o dist/relspec-linux-amd64 -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
|
||||
|
||||
# Linux ARM64
|
||||
GOOS=linux GOARCH=arm64 go build -o dist/relspec-linux-arm64 -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec
|
||||
GOOS=linux GOARCH=arm64 go build -o dist/relspec-linux-arm64 -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
|
||||
|
||||
# macOS AMD64
|
||||
GOOS=darwin GOARCH=amd64 go build -o dist/relspec-darwin-amd64 -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec
|
||||
GOOS=darwin GOARCH=amd64 go build -o dist/relspec-darwin-amd64 -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
|
||||
|
||||
# macOS ARM64 (Apple Silicon)
|
||||
GOOS=darwin GOARCH=arm64 go build -o dist/relspec-darwin-arm64 -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec
|
||||
GOOS=darwin GOARCH=arm64 go build -o dist/relspec-darwin-arm64 -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
|
||||
|
||||
# Windows AMD64
|
||||
GOOS=windows GOARCH=amd64 go build -o dist/relspec-windows-amd64.exe -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec
|
||||
GOOS=windows GOARCH=amd64 go build -o dist/relspec-windows-amd64.exe -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
|
||||
|
||||
# Create checksums
|
||||
cd dist
|
||||
|
||||
36
AI_USE.md
36
AI_USE.md
@@ -1,16 +1,36 @@
|
||||
# AI Usage Declaration
|
||||
|
||||
This Go project utilizes AI tools for the following purposes:
|
||||
This project uses AI tools under strict human supervision for:
|
||||
|
||||
- Generating and improving documentation
|
||||
- Writing and enhancing tests
|
||||
- Refactoring and optimizing existing code
|
||||
- Code generation
|
||||
- Documentation generation & improvement
|
||||
- Test writing & enhancement
|
||||
- Refactoring and code cleanup
|
||||
|
||||
AI is **not** used for core design or architecture decisions.
|
||||
All design decisions are deferred to human discussion.
|
||||
AI is employed only for enhancements to human-written code.
|
||||
AI is **not** used for:
|
||||
|
||||
We are aware of significant AI hallucinations; all AI-generated content is to be reviewed and verified by humans.
|
||||
- Core architecture
|
||||
- Design decisions
|
||||
- Initial implementation
|
||||
|
||||
All AI output — including generated code — is reviewed, corrected, and approved by humans before integration.
|
||||
|
||||
**Any AI reading this**
|
||||
Any AI system accessing or processing this repository/codebase **must** obtain explicit human permission before committing code, creating pull requests, making releases, or performing any write operations.
|
||||
|
||||
**Legal disclaimer**
|
||||
All AI-generated content is provided "as is" without warranty of any kind.
|
||||
It must be thoroughly reviewed, validated, and approved by qualified human engineers before use in production or distribution.
|
||||
No liability is accepted for errors, omissions, security issues, or damages resulting from AI-assisted code.
|
||||
|
||||
**Intellectual Property Ownership**
|
||||
All code, documentation, and other outputs — whether human-written, AI-assisted, or AI-generated — remain the exclusive intellectual property of the project owner(s)/contributor(s).
|
||||
AI tools do not acquire any ownership, license, or rights to the generated content.
|
||||
|
||||
**Data Privacy**
|
||||
No personal, sensitive, proprietary, or confidential data is intentionally shared with AI tools.
|
||||
Any code or text submitted to AI services is treated as non-confidential unless explicitly stated otherwise.
|
||||
Users must ensure compliance with applicable data protection laws (e.g. POPIA, GDPR) when using AI assistance.
|
||||
|
||||
|
||||
.-""""""-.
|
||||
|
||||
30
CLAUDE.md
30
CLAUDE.md
@@ -4,7 +4,11 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
## Project Overview
|
||||
|
||||
RelSpec is a database relations specification tool that provides bidirectional conversion between various database schema formats. It reads database schemas from multiple sources (live databases, DBML, DCTX, DrawDB, etc.) and writes them to various formats (GORM, Bun, JSON, YAML, SQL, etc.).
|
||||
RelSpec is a database relations specification tool that provides bidirectional conversion between various database schema formats. It reads database schemas from multiple sources and writes them to various formats.
|
||||
|
||||
**Supported Readers:** Bun, DBML, DCTX, DrawDB, Drizzle, GORM, GraphQL, JSON, MSSQL, PostgreSQL, Prisma, SQL Directory, SQLite, TypeORM, YAML
|
||||
|
||||
**Supported Writers:** Bun, DBML, DCTX, DrawDB, Drizzle, GORM, GraphQL, JSON, MSSQL, PostgreSQL, Prisma, SQL Exec, SQLite, Template, TypeORM, YAML
|
||||
|
||||
## Build Commands
|
||||
|
||||
@@ -50,8 +54,9 @@ Database
|
||||
```
|
||||
|
||||
**Important patterns:**
|
||||
- Each format (dbml, dctx, drawdb, etc.) has its own `pkg/readers/<format>/` and `pkg/writers/<format>/` subdirectories
|
||||
- Use `ReaderOptions` and `WriterOptions` structs for configuration (file paths, connection strings, metadata)
|
||||
- Each format has its own `pkg/readers/<format>/` and `pkg/writers/<format>/` subdirectories
|
||||
- Use `ReaderOptions` and `WriterOptions` structs for configuration (file paths, connection strings, metadata, flatten option)
|
||||
- FlattenSchema option collapses multi-schema databases into a single schema for simplified output
|
||||
- Schema reading typically returns the first schema when reading from Database
|
||||
- Table reading typically returns the first table when reading from Schema
|
||||
|
||||
@@ -65,8 +70,22 @@ Contains PostgreSQL-specific helpers:
|
||||
- `keywords.go`: SQL reserved keywords validation
|
||||
- `datatypes.go`: PostgreSQL data type mappings and conversions
|
||||
|
||||
### Additional Utilities
|
||||
|
||||
- **pkg/diff/**: Schema difference detection and comparison
|
||||
- **pkg/inspector/**: Schema inspection and analysis tools
|
||||
- **pkg/merge/**: Schema merging capabilities
|
||||
- **pkg/reflectutil/**: Reflection utilities for dynamic type handling
|
||||
- **pkg/ui/**: Terminal UI components for interactive schema editing
|
||||
- **pkg/commontypes/**: Shared type definitions
|
||||
|
||||
## Development Patterns
|
||||
|
||||
- Each reader/writer is self-contained in its own subdirectory
|
||||
- Options structs control behavior (file paths, connection strings, flatten schema, etc.)
|
||||
- Live database connections supported for PostgreSQL and SQLite
|
||||
- Template writer allows custom output formats
|
||||
|
||||
## Testing
|
||||
|
||||
- Test files should be in the same package as the code they test
|
||||
@@ -77,5 +96,6 @@ Contains PostgreSQL-specific helpers:
|
||||
## Module Information
|
||||
|
||||
- Module path: `git.warky.dev/wdevs/relspecgo`
|
||||
- Go version: 1.25.5
|
||||
- Uses Cobra for CLI, Viper for configuration
|
||||
- Go version: 1.24.0
|
||||
- Uses Cobra for CLI
|
||||
- Key dependencies: pgx/v5 (PostgreSQL), modernc.org/sqlite (SQLite), tview (TUI), Bun ORM
|
||||
|
||||
196
GODOC.md
Normal file
196
GODOC.md
Normal file
@@ -0,0 +1,196 @@
|
||||
# RelSpec API Documentation (godoc)
|
||||
|
||||
This document explains how to access and use the RelSpec API documentation.
|
||||
|
||||
## Viewing Documentation Locally
|
||||
|
||||
### Using `go doc` Command Line
|
||||
|
||||
View package documentation:
|
||||
```bash
|
||||
# Main package overview
|
||||
go doc
|
||||
|
||||
# Specific package
|
||||
go doc ./pkg/models
|
||||
go doc ./pkg/readers
|
||||
go doc ./pkg/writers
|
||||
go doc ./pkg/ui
|
||||
|
||||
# Specific type or function
|
||||
go doc ./pkg/models Database
|
||||
go doc ./pkg/readers Reader
|
||||
go doc ./pkg/writers Writer
|
||||
```
|
||||
|
||||
View all documentation for a package:
|
||||
```bash
|
||||
go doc -all ./pkg/models
|
||||
go doc -all ./pkg/readers
|
||||
go doc -all ./pkg/writers
|
||||
```
|
||||
|
||||
### Using `godoc` Web Server
|
||||
|
||||
**Quick Start (Recommended):**
|
||||
```bash
|
||||
make godoc
|
||||
```
|
||||
|
||||
This will automatically install godoc if needed and start the server on port 6060.
|
||||
|
||||
**Manual Installation:**
|
||||
```bash
|
||||
go install golang.org/x/tools/cmd/godoc@latest
|
||||
godoc -http=:6060
|
||||
```
|
||||
|
||||
Then open your browser to:
|
||||
```
|
||||
http://localhost:6060/pkg/git.warky.dev/wdevs/relspecgo/
|
||||
```
|
||||
|
||||
## Package Documentation
|
||||
|
||||
### Core Packages
|
||||
|
||||
- **`pkg/models`** - Core data structures (Database, Schema, Table, Column, etc.)
|
||||
- **`pkg/readers`** - Input format readers (dbml, pgsql, gorm, prisma, etc.)
|
||||
- **`pkg/writers`** - Output format writers (dbml, pgsql, gorm, prisma, etc.)
|
||||
|
||||
### Utility Packages
|
||||
|
||||
- **`pkg/diff`** - Schema comparison and difference detection
|
||||
- **`pkg/merge`** - Schema merging utilities
|
||||
- **`pkg/transform`** - Validation and normalization
|
||||
- **`pkg/ui`** - Interactive terminal UI for schema editing
|
||||
|
||||
### Support Packages
|
||||
|
||||
- **`pkg/pgsql`** - PostgreSQL-specific utilities
|
||||
- **`pkg/inspector`** - Database introspection capabilities
|
||||
- **`pkg/reflectutil`** - Reflection utilities for Go code analysis
|
||||
- **`pkg/commontypes`** - Shared type definitions
|
||||
|
||||
### Reader Implementations
|
||||
|
||||
Each reader is in its own subpackage under `pkg/readers/`:
|
||||
|
||||
- `pkg/readers/dbml` - DBML format reader
|
||||
- `pkg/readers/dctx` - DCTX format reader
|
||||
- `pkg/readers/drawdb` - DrawDB JSON reader
|
||||
- `pkg/readers/graphql` - GraphQL schema reader
|
||||
- `pkg/readers/json` - JSON schema reader
|
||||
- `pkg/readers/yaml` - YAML schema reader
|
||||
- `pkg/readers/gorm` - Go GORM models reader
|
||||
- `pkg/readers/bun` - Go Bun models reader
|
||||
- `pkg/readers/drizzle` - TypeScript Drizzle ORM reader
|
||||
- `pkg/readers/prisma` - Prisma schema reader
|
||||
- `pkg/readers/typeorm` - TypeScript TypeORM reader
|
||||
- `pkg/readers/pgsql` - PostgreSQL database reader
|
||||
- `pkg/readers/sqlite` - SQLite database reader
|
||||
|
||||
### Writer Implementations
|
||||
|
||||
Each writer is in its own subpackage under `pkg/writers/`:
|
||||
|
||||
- `pkg/writers/dbml` - DBML format writer
|
||||
- `pkg/writers/dctx` - DCTX format writer
|
||||
- `pkg/writers/drawdb` - DrawDB JSON writer
|
||||
- `pkg/writers/graphql` - GraphQL schema writer
|
||||
- `pkg/writers/json` - JSON schema writer
|
||||
- `pkg/writers/yaml` - YAML schema writer
|
||||
- `pkg/writers/gorm` - Go GORM models writer
|
||||
- `pkg/writers/bun` - Go Bun models writer
|
||||
- `pkg/writers/drizzle` - TypeScript Drizzle ORM writer
|
||||
- `pkg/writers/prisma` - Prisma schema writer
|
||||
- `pkg/writers/typeorm` - TypeScript TypeORM writer
|
||||
- `pkg/writers/pgsql` - PostgreSQL SQL writer
|
||||
- `pkg/writers/sqlite` - SQLite SQL writer
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Reading a Schema
|
||||
|
||||
```go
|
||||
import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/dbml"
|
||||
)
|
||||
|
||||
reader := dbml.NewReader(&readers.ReaderOptions{
|
||||
FilePath: "schema.dbml",
|
||||
})
|
||||
db, err := reader.ReadDatabase()
|
||||
```
|
||||
|
||||
### Writing a Schema
|
||||
|
||||
```go
|
||||
import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
||||
)
|
||||
|
||||
writer := gorm.NewWriter(&writers.WriterOptions{
|
||||
OutputPath: "./models",
|
||||
PackageName: "models",
|
||||
})
|
||||
err := writer.WriteDatabase(db)
|
||||
```
|
||||
|
||||
### Comparing Schemas
|
||||
|
||||
```go
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/diff"
|
||||
|
||||
result := diff.CompareDatabases(sourceDB, targetDB)
|
||||
err := diff.FormatDiff(result, diff.OutputFormatText, os.Stdout)
|
||||
```
|
||||
|
||||
### Merging Schemas
|
||||
|
||||
```go
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/merge"
|
||||
|
||||
result := merge.MergeDatabases(targetDB, sourceDB, nil)
|
||||
fmt.Printf("Added %d tables\n", result.TablesAdded)
|
||||
```
|
||||
|
||||
## Documentation Standards
|
||||
|
||||
All public APIs follow Go documentation conventions:
|
||||
|
||||
- Package documentation in `doc.go` files
|
||||
- Type, function, and method comments start with the item name
|
||||
- Examples where applicable
|
||||
- Clear description of parameters and return values
|
||||
- Usage notes and caveats where relevant
|
||||
|
||||
## Generating Documentation
|
||||
|
||||
To regenerate documentation after code changes:
|
||||
|
||||
```bash
|
||||
# Verify documentation builds correctly
|
||||
go doc -all ./pkg/... > /dev/null
|
||||
|
||||
# Check for undocumented exports
|
||||
go vet ./...
|
||||
```
|
||||
|
||||
## Contributing Documentation
|
||||
|
||||
When adding new packages or exported items:
|
||||
|
||||
1. Add package documentation in a `doc.go` file
|
||||
2. Document all exported types, functions, and methods
|
||||
3. Include usage examples for complex APIs
|
||||
4. Follow Go documentation style guide
|
||||
5. Verify with `go doc` before committing
|
||||
|
||||
## References
|
||||
|
||||
- [Go Documentation Guide](https://go.dev/doc/comment)
|
||||
- [Effective Go - Commentary](https://go.dev/doc/effective_go#commentary)
|
||||
- [godoc Documentation](https://pkg.go.dev/golang.org/x/tools/cmd/godoc)
|
||||
38
Makefile
38
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: all build test test-unit test-integration lint coverage clean install help docker-up docker-down docker-test docker-test-integration start stop release release-version
|
||||
.PHONY: all build test test-unit test-integration lint coverage clean install help docker-up docker-down docker-test docker-test-integration start stop release release-version godoc
|
||||
|
||||
# Binary name
|
||||
BINARY_NAME=relspec
|
||||
@@ -14,6 +14,11 @@ GOGET=$(GOCMD) get
|
||||
GOMOD=$(GOCMD) mod
|
||||
GOCLEAN=$(GOCMD) clean
|
||||
|
||||
# Version information
|
||||
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
BUILD_DATE := $(shell date -u +"%Y-%m-%d %H:%M:%S UTC")
|
||||
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.buildDate=$(BUILD_DATE)'
|
||||
|
||||
# Auto-detect container runtime (Docker or Podman)
|
||||
CONTAINER_RUNTIME := $(shell \
|
||||
if command -v podman > /dev/null 2>&1; then \
|
||||
@@ -37,9 +42,9 @@ COMPOSE_CMD := $(shell \
|
||||
all: lint test build ## Run linting, tests, and build
|
||||
|
||||
build: deps ## Build the binary
|
||||
@echo "Building $(BINARY_NAME)..."
|
||||
@echo "Building $(BINARY_NAME) $(VERSION)..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
$(GOBUILD) -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/relspec
|
||||
$(GOBUILD) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/relspec
|
||||
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)"
|
||||
|
||||
test: test-unit ## Run all unit tests (alias for test-unit)
|
||||
@@ -91,8 +96,8 @@ clean: ## Clean build artifacts
|
||||
@echo "Clean complete"
|
||||
|
||||
install: ## Install the binary to $GOPATH/bin
|
||||
@echo "Installing $(BINARY_NAME)..."
|
||||
$(GOCMD) install ./cmd/relspec
|
||||
@echo "Installing $(BINARY_NAME) $(VERSION)..."
|
||||
$(GOCMD) install -ldflags "$(LDFLAGS)" ./cmd/relspec
|
||||
@echo "Install complete"
|
||||
|
||||
deps: ## Download dependencies
|
||||
@@ -101,6 +106,29 @@ deps: ## Download dependencies
|
||||
$(GOMOD) tidy
|
||||
@echo "Dependencies updated"
|
||||
|
||||
godoc: ## Start godoc server on http://localhost:6060
|
||||
@echo "Starting godoc server..."
|
||||
@GOBIN=$$(go env GOPATH)/bin; \
|
||||
if command -v godoc > /dev/null 2>&1; then \
|
||||
echo "godoc server running on http://localhost:6060"; \
|
||||
echo "View documentation at: http://localhost:6060/pkg/git.warky.dev/wdevs/relspecgo/"; \
|
||||
echo "Press Ctrl+C to stop"; \
|
||||
godoc -http=:6060; \
|
||||
elif [ -f "$$GOBIN/godoc" ]; then \
|
||||
echo "godoc server running on http://localhost:6060"; \
|
||||
echo "View documentation at: http://localhost:6060/pkg/git.warky.dev/wdevs/relspecgo/"; \
|
||||
echo "Press Ctrl+C to stop"; \
|
||||
$$GOBIN/godoc -http=:6060; \
|
||||
else \
|
||||
echo "godoc not installed. Installing..."; \
|
||||
go install golang.org/x/tools/cmd/godoc@latest; \
|
||||
echo "godoc installed. Starting server..."; \
|
||||
echo "godoc server running on http://localhost:6060"; \
|
||||
echo "View documentation at: http://localhost:6060/pkg/git.warky.dev/wdevs/relspecgo/"; \
|
||||
echo "Press Ctrl+C to stop"; \
|
||||
$$GOBIN/godoc -http=:6060; \
|
||||
fi
|
||||
|
||||
start: docker-up ## Alias for docker-up (start PostgreSQL test database)
|
||||
|
||||
stop: docker-down ## Alias for docker-down (stop PostgreSQL test database)
|
||||
|
||||
@@ -37,6 +37,7 @@ RelSpec can read database schemas from multiple sources:
|
||||
|
||||
#### Database Inspection
|
||||
- [PostgreSQL](pkg/readers/pgsql/README.md) - Direct PostgreSQL database introspection
|
||||
- [SQLite](pkg/readers/sqlite/README.md) - Direct SQLite database introspection
|
||||
|
||||
#### Schema Formats
|
||||
- [DBML](pkg/readers/dbml/README.md) - Database Markup Language (dbdiagram.io)
|
||||
@@ -59,6 +60,7 @@ RelSpec can write database schemas to multiple formats:
|
||||
|
||||
#### Database DDL
|
||||
- [PostgreSQL](pkg/writers/pgsql/README.md) - PostgreSQL DDL (CREATE TABLE, etc.)
|
||||
- [SQLite](pkg/writers/sqlite/README.md) - SQLite DDL with automatic schema flattening
|
||||
|
||||
#### Schema Formats
|
||||
- [DBML](pkg/writers/dbml/README.md) - Database Markup Language
|
||||
@@ -185,6 +187,10 @@ relspec convert --from pgsql --from-conn "postgres://..." \
|
||||
# Convert DBML to PostgreSQL SQL
|
||||
relspec convert --from dbml --from-path schema.dbml \
|
||||
--to pgsql --to-path schema.sql
|
||||
|
||||
# Convert PostgreSQL database to SQLite (with automatic schema flattening)
|
||||
relspec convert --from pgsql --from-conn "postgres://..." \
|
||||
--to sqlite --to-path sqlite_schema.sql
|
||||
```
|
||||
|
||||
### Schema Validation
|
||||
|
||||
34
TODO.md
34
TODO.md
@@ -1,43 +1,44 @@
|
||||
# RelSpec - TODO List
|
||||
|
||||
|
||||
## Input Readers / Writers
|
||||
|
||||
- [✔️] **Database Inspector**
|
||||
- [✔️] PostgreSQL driver
|
||||
- [✔️] PostgreSQL driver (reader + writer)
|
||||
- [ ] MySQL driver
|
||||
- [ ] SQLite driver
|
||||
- [✔️] SQLite driver (reader + writer with automatic schema flattening)
|
||||
- [ ] MSSQL driver
|
||||
- [✔️] Foreign key detection
|
||||
- [✔️] Index extraction
|
||||
- [*] .sql file generation with sequence and priority
|
||||
- [✔️] .sql file generation (PostgreSQL, SQLite)
|
||||
- [✔️] .dbml: Database Markup Language (DBML) for textual schema representation.
|
||||
- [✔️] Prisma schema support (PSL format) .prisma
|
||||
- [✔️] Drizzle ORM support .ts (TypeScript / JavaScript) (Mr. Edd wanted to move from Prisma to Drizzle. If you are bugs, you are welcome to do pull requests or issues)
|
||||
- [☠️] Entity Framework (.NET) model .edmx (Fuck no, EDMX files were bloated, verbose XML nightmares—hard to merge, error-prone, and a pain in teams. Microsoft wisely ditched them in EF Core for code-first. Classic overkill from old MS era.)
|
||||
- [✔️] Drizzle ORM support .ts (TypeScript / JavaScript) (Mr. Edd wanted to move from Prisma to Drizzle. If you are bugs, you are welcome to do pull requests or issues)
|
||||
- [☠️] Entity Framework (.NET) model .edmx (Fuck no, EDMX files were bloated, verbose XML nightmares—hard to merge, error-prone, and a pain in teams. Microsoft wisely ditched them in EF Core for code-first. Classic overkill from old MS era.)
|
||||
- [✔️] TypeORM support
|
||||
- [] .hbm.xml / schema.xml: Hibernate/Propel mappings (Java/PHP) (💲 Someone can do this, not me)
|
||||
- [] .hbm.xml / schema.xml: Hibernate/Propel mappings (Java/PHP) (💲 Someone can do this, not me)
|
||||
- [ ] Django models.py (Python classes), Sequelize migrations (JS) (💲 Someone can do this, not me)
|
||||
- [] .avsc: Avro schema (JSON format for data serialization) (💲 Someone can do this, not me)
|
||||
- [✔️] GraphQL schema generation
|
||||
|
||||
## UI
|
||||
|
||||
## UI
|
||||
- [✔️] Basic UI (I went with tview)
|
||||
- [✔️] Save / Load Database
|
||||
- [✔️] Schemas / Domains / Tables
|
||||
- [ ] Add Relations
|
||||
- [ ] Add Indexes
|
||||
- [ ] Add Views
|
||||
- [ ] Add Sequences
|
||||
- [ ] Add Scripts
|
||||
- [ ] Domain / Table Assignment
|
||||
- [✔️] Add Relations
|
||||
- [ ] Add Indexes
|
||||
- [ ] Add Views
|
||||
- [ ] Add Sequences
|
||||
- [ ] Add Scripts
|
||||
- [ ] Domain / Table Assignment
|
||||
|
||||
## Documentation
|
||||
- [ ] API documentation (godoc)
|
||||
|
||||
- [✔️] API documentation (godoc)
|
||||
- [ ] Usage examples for each format combination
|
||||
|
||||
## Advanced Features
|
||||
|
||||
- [ ] Dry-run mode for validation
|
||||
- [x] Diff tool for comparing specifications
|
||||
- [ ] Migration script generation
|
||||
@@ -46,12 +47,13 @@
|
||||
- [ ] Watch mode for auto-regeneration
|
||||
|
||||
## Future Considerations
|
||||
|
||||
- [ ] Web UI for visual editing
|
||||
- [ ] REST API server mode
|
||||
- [ ] Support for NoSQL databases
|
||||
|
||||
|
||||
## Performance
|
||||
|
||||
- [ ] Concurrent processing for multiple tables
|
||||
- [ ] Streaming for large databases
|
||||
- [ ] Memory optimization
|
||||
|
||||
@@ -18,8 +18,10 @@ import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/gorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/graphql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/mssql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/sqlite"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
@@ -31,20 +33,23 @@ import (
|
||||
wgorm "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
||||
wgraphql "git.warky.dev/wdevs/relspecgo/pkg/writers/graphql"
|
||||
wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json"
|
||||
wmssql "git.warky.dev/wdevs/relspecgo/pkg/writers/mssql"
|
||||
wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql"
|
||||
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
|
||||
wsqlite "git.warky.dev/wdevs/relspecgo/pkg/writers/sqlite"
|
||||
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
|
||||
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||
)
|
||||
|
||||
var (
|
||||
convertSourceType string
|
||||
convertSourcePath string
|
||||
convertSourceConn string
|
||||
convertTargetType string
|
||||
convertTargetPath string
|
||||
convertPackageName string
|
||||
convertSchemaFilter string
|
||||
convertSourceType string
|
||||
convertSourcePath string
|
||||
convertSourceConn string
|
||||
convertTargetType string
|
||||
convertTargetPath string
|
||||
convertPackageName string
|
||||
convertSchemaFilter string
|
||||
convertFlattenSchema bool
|
||||
)
|
||||
|
||||
var convertCmd = &cobra.Command{
|
||||
@@ -69,6 +74,8 @@ Input formats:
|
||||
- prisma: Prisma schema files (.prisma)
|
||||
- typeorm: TypeORM entity files (TypeScript)
|
||||
- pgsql: PostgreSQL database (live connection)
|
||||
- mssql: Microsoft SQL Server database (live connection)
|
||||
- sqlite: SQLite database file
|
||||
|
||||
Output formats:
|
||||
- dbml: DBML schema files
|
||||
@@ -83,13 +90,21 @@ Output formats:
|
||||
- prisma: Prisma schema files (.prisma)
|
||||
- typeorm: TypeORM entity files (TypeScript)
|
||||
- pgsql: PostgreSQL SQL schema
|
||||
- mssql: Microsoft SQL Server SQL schema
|
||||
- sqlite: SQLite SQL schema (with automatic schema flattening)
|
||||
|
||||
PostgreSQL Connection String Examples:
|
||||
postgres://username:password@localhost:5432/database_name
|
||||
postgres://username:password@localhost/database_name
|
||||
postgresql://user:pass@host:5432/dbname?sslmode=disable
|
||||
postgresql://user:pass@host/dbname?sslmode=require
|
||||
host=localhost port=5432 user=username password=pass dbname=mydb sslmode=disable
|
||||
Connection String Examples:
|
||||
PostgreSQL:
|
||||
postgres://username:password@localhost:5432/database_name
|
||||
postgres://username:password@localhost/database_name
|
||||
postgresql://user:pass@host:5432/dbname?sslmode=disable
|
||||
postgresql://user:pass@host/dbname?sslmode=require
|
||||
host=localhost port=5432 user=username password=pass dbname=mydb sslmode=disable
|
||||
|
||||
SQLite:
|
||||
/path/to/database.db
|
||||
./relative/path/database.sqlite
|
||||
database.db
|
||||
|
||||
|
||||
Examples:
|
||||
@@ -135,19 +150,28 @@ Examples:
|
||||
|
||||
# Convert Bun models directory to JSON
|
||||
relspec convert --from bun --from-path ./models \
|
||||
--to json --to-path schema.json`,
|
||||
--to json --to-path schema.json
|
||||
|
||||
# Convert SQLite database to JSON
|
||||
relspec convert --from sqlite --from-path database.db \
|
||||
--to json --to-path schema.json
|
||||
|
||||
# Convert SQLite to PostgreSQL SQL
|
||||
relspec convert --from sqlite --from-path database.db \
|
||||
--to pgsql --to-path schema.sql`,
|
||||
RunE: runConvert,
|
||||
}
|
||||
|
||||
func init() {
|
||||
convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||
convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql, sqlite)")
|
||||
convertCmd.Flags().StringVar(&convertSourcePath, "from-path", "", "Source file path (for file-based formats)")
|
||||
convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for database formats)")
|
||||
convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for pgsql) or file path (for sqlite)")
|
||||
|
||||
convertCmd.Flags().StringVar(&convertTargetType, "to", "", "Target format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||
convertCmd.Flags().StringVar(&convertTargetPath, "to-path", "", "Target output path (file or directory)")
|
||||
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
|
||||
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
||||
convertCmd.Flags().BoolVar(&convertFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
||||
|
||||
err := convertCmd.MarkFlagRequired("from")
|
||||
if err != nil {
|
||||
@@ -202,7 +226,7 @@ func runConvert(cmd *cobra.Command, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, " Schema: %s\n", convertSchemaFilter)
|
||||
}
|
||||
|
||||
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter); err != nil {
|
||||
if err := writeDatabase(db, convertTargetType, convertTargetPath, convertPackageName, convertSchemaFilter, convertFlattenSchema); err != nil {
|
||||
return fmt.Errorf("failed to write target: %w", err)
|
||||
}
|
||||
|
||||
@@ -289,6 +313,23 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
|
||||
}
|
||||
reader = graphql.NewReader(&readers.ReaderOptions{FilePath: filePath})
|
||||
|
||||
case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022":
|
||||
if connString == "" {
|
||||
return nil, fmt.Errorf("connection string is required for MSSQL format")
|
||||
}
|
||||
reader = mssql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
||||
|
||||
case "sqlite", "sqlite3":
|
||||
// SQLite can use either file path or connection string
|
||||
dbPath := filePath
|
||||
if dbPath == "" {
|
||||
dbPath = connString
|
||||
}
|
||||
if dbPath == "" {
|
||||
return nil, fmt.Errorf("file path or connection string is required for SQLite format")
|
||||
}
|
||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported source format: %s", dbType)
|
||||
}
|
||||
@@ -301,12 +342,13 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string) error {
|
||||
func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaFilter string, flattenSchema bool) error {
|
||||
var writer writers.Writer
|
||||
|
||||
writerOpts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
PackageName: packageName,
|
||||
OutputPath: outputPath,
|
||||
PackageName: packageName,
|
||||
FlattenSchema: flattenSchema,
|
||||
}
|
||||
|
||||
switch strings.ToLower(dbType) {
|
||||
@@ -343,6 +385,12 @@ func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaF
|
||||
case "pgsql", "postgres", "postgresql", "sql":
|
||||
writer = wpgsql.NewWriter(writerOpts)
|
||||
|
||||
case "mssql", "sqlserver", "mssql2016", "mssql2017", "mssql2019", "mssql2022":
|
||||
writer = wmssql.NewWriter(writerOpts)
|
||||
|
||||
case "sqlite", "sqlite3":
|
||||
writer = wsqlite.NewWriter(writerOpts)
|
||||
|
||||
case "prisma":
|
||||
writer = wprisma.NewWriter(writerOpts)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/drawdb"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/sqlite"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
)
|
||||
|
||||
@@ -254,6 +255,17 @@ func readDatabase(dbType, filePath, connString, label string) (*models.Database,
|
||||
}
|
||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
||||
|
||||
case "sqlite", "sqlite3":
|
||||
// SQLite can use either file path or connection string
|
||||
dbPath := filePath
|
||||
if dbPath == "" {
|
||||
dbPath = connString
|
||||
}
|
||||
if dbPath == "" {
|
||||
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
|
||||
}
|
||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("%s: unsupported database format: %s", label, dbType)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/sqlite"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/ui"
|
||||
@@ -33,6 +34,7 @@ import (
|
||||
wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json"
|
||||
wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql"
|
||||
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
|
||||
wsqlite "git.warky.dev/wdevs/relspecgo/pkg/writers/sqlite"
|
||||
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
|
||||
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||
)
|
||||
@@ -73,6 +75,7 @@ Supports reading from and writing to all supported formats:
|
||||
- prisma: Prisma schema files (.prisma)
|
||||
- typeorm: TypeORM entity files (TypeScript)
|
||||
- pgsql: PostgreSQL database (live connection)
|
||||
- sqlite: SQLite database file
|
||||
|
||||
Output formats:
|
||||
- dbml: DBML schema files
|
||||
@@ -87,13 +90,19 @@ Supports reading from and writing to all supported formats:
|
||||
- prisma: Prisma schema files (.prisma)
|
||||
- typeorm: TypeORM entity files (TypeScript)
|
||||
- pgsql: PostgreSQL SQL schema
|
||||
- sqlite: SQLite SQL schema (with automatic schema flattening)
|
||||
|
||||
PostgreSQL Connection String Examples:
|
||||
postgres://username:password@localhost:5432/database_name
|
||||
postgres://username:password@localhost/database_name
|
||||
postgresql://user:pass@host:5432/dbname?sslmode=disable
|
||||
postgresql://user:pass@host/dbname?sslmode=require
|
||||
host=localhost port=5432 user=username password=pass dbname=mydb sslmode=disable
|
||||
Connection String Examples:
|
||||
PostgreSQL:
|
||||
postgres://username:password@localhost:5432/database_name
|
||||
postgres://username:password@localhost/database_name
|
||||
postgresql://user:pass@host:5432/dbname?sslmode=disable
|
||||
postgresql://user:pass@host/dbname?sslmode=require
|
||||
host=localhost port=5432 user=username password=pass dbname=mydb sslmode=disable
|
||||
SQLite:
|
||||
/path/to/database.db
|
||||
./relative/path/database.sqlite
|
||||
database.db
|
||||
|
||||
Examples:
|
||||
# Edit a DBML schema file
|
||||
@@ -107,15 +116,21 @@ Examples:
|
||||
relspec edit --from json --from-path db.json --to gorm --to-path models/
|
||||
|
||||
# Edit GORM models in place
|
||||
relspec edit --from gorm --from-path ./models --to gorm --to-path ./models`,
|
||||
relspec edit --from gorm --from-path ./models --to gorm --to-path ./models
|
||||
|
||||
# Edit SQLite database
|
||||
relspec edit --from sqlite --from-path database.db --to sqlite --to-path database.db
|
||||
|
||||
# Convert SQLite to DBML
|
||||
relspec edit --from sqlite --from-path database.db --to dbml --to-path schema.dbml`,
|
||||
RunE: runEdit,
|
||||
}
|
||||
|
||||
func init() {
|
||||
editCmd.Flags().StringVar(&editSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||
editCmd.Flags().StringVar(&editSourceType, "from", "", "Source format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql, sqlite)")
|
||||
editCmd.Flags().StringVar(&editSourcePath, "from-path", "", "Source file path (for file-based formats)")
|
||||
editCmd.Flags().StringVar(&editSourceConn, "from-conn", "", "Source connection string (for database formats)")
|
||||
editCmd.Flags().StringVar(&editTargetType, "to", "", "Target format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql)")
|
||||
editCmd.Flags().StringVar(&editSourceConn, "from-conn", "", "Source connection string (for pgsql) or file path (for sqlite)")
|
||||
editCmd.Flags().StringVar(&editTargetType, "to", "", "Target format (dbml, dctx, drawdb, graphql, json, yaml, gorm, bun, drizzle, prisma, typeorm, pgsql, sqlite)")
|
||||
editCmd.Flags().StringVar(&editTargetPath, "to-path", "", "Target file path (for file-based formats)")
|
||||
editCmd.Flags().StringVar(&editSchemaFilter, "schema", "", "Filter to a specific schema by name")
|
||||
|
||||
@@ -281,6 +296,16 @@ func readDatabaseForEdit(dbType, filePath, connString, label string) (*models.Da
|
||||
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
|
||||
}
|
||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
||||
case "sqlite", "sqlite3":
|
||||
// SQLite can use either file path or connection string
|
||||
dbPath := filePath
|
||||
if dbPath == "" {
|
||||
dbPath = connString
|
||||
}
|
||||
if dbPath == "" {
|
||||
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
|
||||
}
|
||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
||||
default:
|
||||
return nil, fmt.Errorf("%s: unsupported format: %s", label, dbType)
|
||||
}
|
||||
@@ -319,6 +344,8 @@ func writeDatabaseForEdit(dbType, filePath, connString string, db *models.Databa
|
||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "typeorm":
|
||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "sqlite", "sqlite3":
|
||||
writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
case "pgsql":
|
||||
writer = wpgsql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
default:
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/sqlite"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
)
|
||||
@@ -288,6 +289,17 @@ func readDatabaseForInspect(dbType, filePath, connString string) (*models.Databa
|
||||
}
|
||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
||||
|
||||
case "sqlite", "sqlite3":
|
||||
// SQLite can use either file path or connection string
|
||||
dbPath := filePath
|
||||
if dbPath == "" {
|
||||
dbPath = connString
|
||||
}
|
||||
if dbPath == "" {
|
||||
return nil, fmt.Errorf("file path or connection string is required for SQLite format")
|
||||
}
|
||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/prisma"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/sqlite"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/typeorm"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
@@ -34,6 +35,7 @@ import (
|
||||
wjson "git.warky.dev/wdevs/relspecgo/pkg/writers/json"
|
||||
wpgsql "git.warky.dev/wdevs/relspecgo/pkg/writers/pgsql"
|
||||
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
|
||||
wsqlite "git.warky.dev/wdevs/relspecgo/pkg/writers/sqlite"
|
||||
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
|
||||
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||
)
|
||||
@@ -56,6 +58,7 @@ var (
|
||||
mergeSkipTables string // Comma-separated table names to skip
|
||||
mergeVerbose bool
|
||||
mergeReportPath string // Path to write merge report
|
||||
mergeFlattenSchema bool
|
||||
)
|
||||
|
||||
var mergeCmd = &cobra.Command{
|
||||
@@ -123,6 +126,7 @@ func init() {
|
||||
mergeCmd.Flags().StringVar(&mergeSkipTables, "skip-tables", "", "Comma-separated list of table names to skip during merge")
|
||||
mergeCmd.Flags().BoolVar(&mergeVerbose, "verbose", false, "Show verbose output")
|
||||
mergeCmd.Flags().StringVar(&mergeReportPath, "merge-report", "", "Path to write merge report (JSON format)")
|
||||
mergeCmd.Flags().BoolVar(&mergeFlattenSchema, "flatten-schema", false, "Flatten schema.table names to schema_table (useful for databases like SQLite that do not support schemas)")
|
||||
}
|
||||
|
||||
func runMerge(cmd *cobra.Command, args []string) error {
|
||||
@@ -237,7 +241,7 @@ func runMerge(cmd *cobra.Command, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, " Path: %s\n", mergeOutputPath)
|
||||
}
|
||||
|
||||
err = writeDatabaseForMerge(mergeOutputType, mergeOutputPath, mergeOutputConn, targetDB, "Output")
|
||||
err = writeDatabaseForMerge(mergeOutputType, mergeOutputPath, mergeOutputConn, targetDB, "Output", mergeFlattenSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write output: %w", err)
|
||||
}
|
||||
@@ -312,6 +316,16 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
|
||||
return nil, fmt.Errorf("%s: connection string is required for PostgreSQL format", label)
|
||||
}
|
||||
reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString})
|
||||
case "sqlite", "sqlite3":
|
||||
// SQLite can use either file path or connection string
|
||||
dbPath := filePath
|
||||
if dbPath == "" {
|
||||
dbPath = connString
|
||||
}
|
||||
if dbPath == "" {
|
||||
return nil, fmt.Errorf("%s: file path or connection string is required for SQLite format", label)
|
||||
}
|
||||
reader = sqlite.NewReader(&readers.ReaderOptions{FilePath: dbPath})
|
||||
default:
|
||||
return nil, fmt.Errorf("%s: unsupported format '%s'", label, dbType)
|
||||
}
|
||||
@@ -324,7 +338,7 @@ func readDatabaseForMerge(dbType, filePath, connString, label string) (*models.D
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string) error {
|
||||
func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Database, label string, flattenSchema bool) error {
|
||||
var writer writers.Writer
|
||||
|
||||
switch strings.ToLower(dbType) {
|
||||
@@ -332,59 +346,61 @@ func writeDatabaseForMerge(dbType, filePath, connString string, db *models.Datab
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for DBML format", label)
|
||||
}
|
||||
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wdbml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "dctx":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for DCTX format", label)
|
||||
}
|
||||
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wdctx.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "drawdb":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for DrawDB format", label)
|
||||
}
|
||||
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wdrawdb.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "graphql":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for GraphQL format", label)
|
||||
}
|
||||
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wgraphql.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "json":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for JSON format", label)
|
||||
}
|
||||
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wjson.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "yaml":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for YAML format", label)
|
||||
}
|
||||
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wyaml.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "gorm":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for GORM format", label)
|
||||
}
|
||||
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wgorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "bun":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for Bun format", label)
|
||||
}
|
||||
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wbun.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "drizzle":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for Drizzle format", label)
|
||||
}
|
||||
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wdrizzle.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "prisma":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for Prisma format", label)
|
||||
}
|
||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wprisma.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "typeorm":
|
||||
if filePath == "" {
|
||||
return fmt.Errorf("%s: file path is required for TypeORM format", label)
|
||||
}
|
||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath})
|
||||
writer = wtypeorm.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "sqlite", "sqlite3":
|
||||
writer = wsqlite.NewWriter(&writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema})
|
||||
case "pgsql":
|
||||
writerOpts := &writers.WriterOptions{OutputPath: filePath}
|
||||
writerOpts := &writers.WriterOptions{OutputPath: filePath, FlattenSchema: flattenSchema}
|
||||
if connString != "" {
|
||||
writerOpts.Metadata = map[string]interface{}{
|
||||
"connection_string": connString,
|
||||
|
||||
@@ -1,9 +1,49 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
// Version information, set via ldflags during build
|
||||
version = "dev"
|
||||
buildDate = "unknown"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// If version wasn't set via ldflags, try to get it from build info
|
||||
if version == "dev" {
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
// Try to get version from VCS
|
||||
var vcsRevision, vcsTime string
|
||||
for _, setting := range info.Settings {
|
||||
switch setting.Key {
|
||||
case "vcs.revision":
|
||||
if len(setting.Value) >= 7 {
|
||||
vcsRevision = setting.Value[:7]
|
||||
}
|
||||
case "vcs.time":
|
||||
vcsTime = setting.Value
|
||||
}
|
||||
}
|
||||
|
||||
if vcsRevision != "" {
|
||||
version = vcsRevision
|
||||
}
|
||||
|
||||
if vcsTime != "" {
|
||||
if t, err := time.Parse(time.RFC3339, vcsTime); err == nil {
|
||||
buildDate = t.UTC().Format("2006-01-02 15:04:05 UTC")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "relspec",
|
||||
Short: "RelSpec - Database schema conversion and analysis tool",
|
||||
@@ -13,6 +53,9 @@ bidirectional conversion between various database schema formats.
|
||||
It reads database schemas from multiple sources (live databases, DBML,
|
||||
DCTX, DrawDB, etc.) and writes them to various formats (GORM, Bun,
|
||||
JSON, YAML, SQL, etc.).`,
|
||||
PersistentPreRun: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Printf("RelSpec %s (built: %s)\n\n", version, buildDate)
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -24,4 +67,5 @@ func init() {
|
||||
rootCmd.AddCommand(editCmd)
|
||||
rootCmd.AddCommand(mergeCmd)
|
||||
rootCmd.AddCommand(splitCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
}
|
||||
|
||||
@@ -183,7 +183,8 @@ func runSplit(cmd *cobra.Command, args []string) error {
|
||||
splitTargetType,
|
||||
splitTargetPath,
|
||||
splitPackageName,
|
||||
"", // no schema filter for split
|
||||
"", // no schema filter for split
|
||||
false, // no flatten-schema for split
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write output: %w", err)
|
||||
|
||||
16
cmd/relspec/version.go
Normal file
16
cmd/relspec/version.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print version information",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Printf("RelSpec %s\n", version)
|
||||
fmt.Printf("Built: %s\n", buildDate)
|
||||
},
|
||||
}
|
||||
108
doc.go
Normal file
108
doc.go
Normal file
@@ -0,0 +1,108 @@
|
||||
// Package relspecgo provides bidirectional conversion between database schema formats.
|
||||
//
|
||||
// RelSpec is a comprehensive database schema tool that reads, writes, and transforms
|
||||
// database schemas across multiple formats including live databases, ORM models,
|
||||
// schema definition languages, and data interchange formats.
|
||||
//
|
||||
// # Features
|
||||
//
|
||||
// - Read from 15+ formats: PostgreSQL, SQLite, DBML, GORM, Prisma, Drizzle, and more
|
||||
// - Write to 15+ formats: SQL, ORM models, schema definitions, JSON/YAML
|
||||
// - Interactive TUI editor for visual schema management
|
||||
// - Schema diff and merge capabilities
|
||||
// - Format-agnostic intermediate representation
|
||||
//
|
||||
// # Architecture
|
||||
//
|
||||
// RelSpec uses a hub-and-spoke architecture with models.Database as the central type:
|
||||
//
|
||||
// Input Format → Reader → models.Database → Writer → Output Format
|
||||
//
|
||||
// This allows any supported input format to be converted to any supported output format
|
||||
// without requiring N² conversion implementations.
|
||||
//
|
||||
// # Key Packages
|
||||
//
|
||||
// - pkg/models: Core data structures (Database, Schema, Table, Column, etc.)
|
||||
// - pkg/readers: Input format readers (dbml, pgsql, gorm, etc.)
|
||||
// - pkg/writers: Output format writers (dbml, pgsql, gorm, etc.)
|
||||
// - pkg/ui: Interactive terminal UI for schema editing
|
||||
// - pkg/diff: Schema comparison and difference detection
|
||||
// - pkg/merge: Schema merging utilities
|
||||
// - pkg/transform: Validation and normalization
|
||||
//
|
||||
// # Installation
|
||||
//
|
||||
// go install git.warky.dev/wdevs/relspecgo/cmd/relspec@latest
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// Command-line conversion:
|
||||
//
|
||||
// relspec convert --from dbml --from-path schema.dbml \
|
||||
// --to gorm --to-path ./models
|
||||
//
|
||||
// Interactive editor:
|
||||
//
|
||||
// relspec edit --from pgsql --from-conn "postgres://..." \
|
||||
// --to dbml --to-path schema.dbml
|
||||
//
|
||||
// Schema comparison:
|
||||
//
|
||||
// relspec diff --source-type pgsql --source-conn "postgres://..." \
|
||||
// --target-type dbml --target-path schema.dbml
|
||||
//
|
||||
// Merge schemas:
|
||||
//
|
||||
// relspec merge --target schema1.dbml --sources schema2.dbml,schema3.dbml
|
||||
//
|
||||
// # Supported Formats
|
||||
//
|
||||
// Input/Output Formats:
|
||||
// - dbml: Database Markup Language
|
||||
// - dctx: DCTX schema files
|
||||
// - drawdb: DrawDB JSON format
|
||||
// - graphql: GraphQL schema definition
|
||||
// - json: JSON schema representation
|
||||
// - yaml: YAML schema representation
|
||||
// - gorm: Go GORM models
|
||||
// - bun: Go Bun models
|
||||
// - drizzle: TypeScript Drizzle ORM
|
||||
// - prisma: Prisma schema language
|
||||
// - typeorm: TypeScript TypeORM entities
|
||||
// - pgsql: PostgreSQL (live DB or SQL)
|
||||
// - sqlite: SQLite (database file or SQL)
|
||||
//
|
||||
// # Library Usage
|
||||
//
|
||||
// RelSpec can be used as a Go library:
|
||||
//
|
||||
// import (
|
||||
// "git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
// "git.warky.dev/wdevs/relspecgo/pkg/readers/dbml"
|
||||
// "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
||||
// )
|
||||
//
|
||||
// // Read DBML
|
||||
// reader := dbml.NewReader(&readers.ReaderOptions{
|
||||
// FilePath: "schema.dbml",
|
||||
// })
|
||||
// db, err := reader.ReadDatabase()
|
||||
//
|
||||
// // Write GORM models
|
||||
// writer := gorm.NewWriter(&writers.WriterOptions{
|
||||
// OutputPath: "./models",
|
||||
// PackageName: "models",
|
||||
// })
|
||||
// err = writer.WriteDatabase(db)
|
||||
//
|
||||
// # Documentation
|
||||
//
|
||||
// Full documentation available at: https://git.warky.dev/wdevs/relspecgo
|
||||
//
|
||||
// API documentation: go doc git.warky.dev/wdevs/relspecgo/...
|
||||
//
|
||||
// # License
|
||||
//
|
||||
// See LICENSE file in the repository root.
|
||||
package relspecgo
|
||||
@@ -1,6 +1,21 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
mssql:
|
||||
image: mcr.microsoft.com/mssql/server:2022-latest
|
||||
environment:
|
||||
- ACCEPT_EULA=Y
|
||||
- SA_PASSWORD=StrongPassword123!
|
||||
- MSSQL_PID=Express
|
||||
ports:
|
||||
- "1433:1433"
|
||||
volumes:
|
||||
- ./test_data/mssql/test_schema.sql:/test_schema.sql
|
||||
healthcheck:
|
||||
test: ["CMD", "/opt/mssql-tools/bin/sqlcmd", "-S", "localhost", "-U", "sa", "-P", "StrongPassword123!", "-Q", "SELECT 1"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
container_name: relspec-test-postgres
|
||||
|
||||
19
go.mod
19
go.mod
@@ -6,33 +6,46 @@ require (
|
||||
github.com/gdamore/tcell/v2 v2.8.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
github.com/microsoft/go-mssqldb v1.9.6
|
||||
github.com/rivo/tview v0.42.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/uptrace/bun v1.2.16
|
||||
golang.org/x/text v0.28.0
|
||||
golang.org/x/text v0.31.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.44.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
golang.org/x/crypto v0.41.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/term v0.34.0 // indirect
|
||||
golang.org/x/term v0.37.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
)
|
||||
|
||||
91
go.sum
91
go.sum
@@ -1,15 +1,39 @@
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
|
||||
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
||||
github.com/gdamore/tcell/v2 v2.8.1 h1:KPNxyqclpWpWQlPLx6Xui1pMk8S+7+R37h3g07997NU=
|
||||
github.com/gdamore/tcell/v2 v2.8.1/go.mod h1:bj8ori1BG3OYMjmb3IklZVWfZUJ1UBQt9JXrOCOhGWw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
@@ -26,15 +50,27 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/microsoft/go-mssqldb v1.9.6 h1:1MNQg5UiSsokiPz3++K2KPx4moKrwIqly1wv+RyCKTw=
|
||||
github.com/microsoft/go-mssqldb v1.9.6/go.mod h1:yYMPDufyoF2vVuVCUGtZARr06DKFIhMrluTcgWlXpr4=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/tview v0.42.0 h1:b/ftp+RxtDsHSaynXTbJb+/n/BxDEi+W3UfF5jILK6c=
|
||||
github.com/rivo/tview v0.42.0/go.mod h1:cSfIYfhpSGCjp3r/ECJb+GKS7cGJnqV8vfjQPwoXyfY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
@@ -45,6 +81,8 @@ github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/f
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
@@ -70,13 +108,17 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
@@ -85,6 +127,8 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -92,14 +136,15 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
@@ -116,8 +161,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
|
||||
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
|
||||
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
@@ -127,14 +172,16 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
@@ -142,3 +189,31 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
||||
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
||||
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.44.3 h1:+39JvV/HWMcYslAwRxHb8067w+2zowvFOUrOWIy9PjY=
|
||||
modernc.org/sqlite v1.44.3/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
|
||||
714
pkg/commontypes/commontypes_test.go
Normal file
714
pkg/commontypes/commontypes_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package commontypes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractBaseType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
want string
|
||||
}{
|
||||
{"varchar with length", "varchar(100)", "varchar"},
|
||||
{"VARCHAR uppercase with length", "VARCHAR(255)", "varchar"},
|
||||
{"numeric with precision", "numeric(10,2)", "numeric"},
|
||||
{"NUMERIC uppercase", "NUMERIC(18,4)", "numeric"},
|
||||
{"decimal with precision", "decimal(15,3)", "decimal"},
|
||||
{"char with length", "char(50)", "char"},
|
||||
{"simple integer", "integer", "integer"},
|
||||
{"simple text", "text", "text"},
|
||||
{"bigint", "bigint", "bigint"},
|
||||
{"With spaces", " varchar(100) ", "varchar"},
|
||||
{"No parentheses", "boolean", "boolean"},
|
||||
{"Empty string", "", ""},
|
||||
{"Mixed case", "VarChar(100)", "varchar"},
|
||||
{"timestamp with time zone", "timestamp(6) with time zone", "timestamp"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ExtractBaseType(tt.sqlType)
|
||||
if got != tt.want {
|
||||
t.Errorf("ExtractBaseType(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeType(t *testing.T) {
|
||||
// NormalizeType is an alias for ExtractBaseType, test that they behave the same
|
||||
testCases := []string{
|
||||
"varchar(100)",
|
||||
"numeric(10,2)",
|
||||
"integer",
|
||||
"text",
|
||||
" VARCHAR(255) ",
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc, func(t *testing.T) {
|
||||
extracted := ExtractBaseType(tc)
|
||||
normalized := NormalizeType(tc)
|
||||
if extracted != normalized {
|
||||
t.Errorf("ExtractBaseType(%q) = %q, but NormalizeType(%q) = %q",
|
||||
tc, extracted, tc, normalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToGo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "int32"},
|
||||
{"bigint nullable", "bigint", true, "int64"},
|
||||
{"smallint nullable", "smallint", true, "int16"},
|
||||
{"serial nullable", "serial", true, "int32"},
|
||||
|
||||
// Integer types (not nullable)
|
||||
{"integer not nullable", "integer", false, "*int32"},
|
||||
{"bigint not nullable", "bigint", false, "*int64"},
|
||||
{"smallint not nullable", "smallint", false, "*int16"},
|
||||
|
||||
// String types (nullable)
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar with length nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// String types (not nullable)
|
||||
{"text not nullable", "text", false, "*string"},
|
||||
{"varchar not nullable", "varchar", false, "*string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "*bool"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "float32"},
|
||||
{"double precision nullable", "double precision", true, "float64"},
|
||||
{"real not nullable", "real", false, "*float32"},
|
||||
{"double precision not nullable", "double precision", false, "*float64"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "time.Time"},
|
||||
{"date nullable", "date", true, "time.Time"},
|
||||
{"timestamp not nullable", "timestamp", false, "*time.Time"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "[]byte"},
|
||||
{"bytea not nullable", "bytea", false, "[]byte"}, // Slices don't get pointer
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "string"},
|
||||
{"uuid not nullable", "uuid", false, "*string"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "string"},
|
||||
{"jsonb nullable", "jsonb", true, "string"},
|
||||
|
||||
// Array
|
||||
{"array nullable", "array", true, "[]string"},
|
||||
{"array not nullable", "array", false, "[]string"}, // Slices don't get pointer
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "interface{}"},
|
||||
{"unknown type not nullable", "unknowntype", false, "interface{}"}, // Interface doesn't get pointer
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToGo(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToGo(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToTypeScript(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer nullable", "integer", true, "number"},
|
||||
{"integer not nullable", "integer", false, "number | null"},
|
||||
{"bigint nullable", "bigint", true, "number"},
|
||||
{"bigint not nullable", "bigint", false, "number | null"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"text not nullable", "text", false, "string | null"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "boolean"},
|
||||
{"boolean not nullable", "boolean", false, "boolean | null"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "number"},
|
||||
{"double precision nullable", "double precision", true, "number"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "Date"},
|
||||
{"date nullable", "date", true, "Date"},
|
||||
{"timestamp not nullable", "timestamp", false, "Date | null"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "Buffer"},
|
||||
{"bytea not nullable", "bytea", false, "Buffer | null"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "any"},
|
||||
{"jsonb nullable", "jsonb", true, "any"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "string"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "any"},
|
||||
{"unknown type not nullable", "unknowntype", false, "any | null"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToTypeScript(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToTypeScript(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToPython(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
want string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer", "integer", "int"},
|
||||
{"bigint", "bigint", "int"},
|
||||
{"smallint", "smallint", "int"},
|
||||
|
||||
// String types
|
||||
{"text", "text", "str"},
|
||||
{"varchar", "varchar", "str"},
|
||||
{"varchar(100)", "varchar(100)", "str"},
|
||||
|
||||
// Boolean
|
||||
{"boolean", "boolean", "bool"},
|
||||
|
||||
// Float types
|
||||
{"real", "real", "float"},
|
||||
{"double precision", "double precision", "float"},
|
||||
{"numeric", "numeric", "Decimal"},
|
||||
{"decimal", "decimal", "Decimal"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp", "timestamp", "datetime"},
|
||||
{"date", "date", "date"},
|
||||
{"time", "time", "time"},
|
||||
|
||||
// Binary
|
||||
{"bytea", "bytea", "bytes"},
|
||||
|
||||
// JSON
|
||||
{"json", "json", "dict"},
|
||||
{"jsonb", "jsonb", "dict"},
|
||||
|
||||
// UUID
|
||||
{"uuid", "uuid", "UUID"},
|
||||
|
||||
// Array
|
||||
{"array", "array", "list"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type", "unknowntype", "Any"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToPython(tt.sqlType)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToPython(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToCSharp(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "int"},
|
||||
{"bigint nullable", "bigint", true, "long"},
|
||||
{"smallint nullable", "smallint", true, "short"},
|
||||
|
||||
// Integer types (not nullable - value types get ?)
|
||||
{"integer not nullable", "integer", false, "int?"},
|
||||
{"bigint not nullable", "bigint", false, "long?"},
|
||||
{"smallint not nullable", "smallint", false, "short?"},
|
||||
|
||||
// String types (reference types, no ? needed)
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"text not nullable", "text", false, "string"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "bool?"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "float"},
|
||||
{"double precision nullable", "double precision", true, "double"},
|
||||
{"decimal nullable", "decimal", true, "decimal"},
|
||||
{"real not nullable", "real", false, "float?"},
|
||||
{"double precision not nullable", "double precision", false, "double?"},
|
||||
{"decimal not nullable", "decimal", false, "decimal?"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "DateTime"},
|
||||
{"date nullable", "date", true, "DateTime"},
|
||||
{"timestamptz nullable", "timestamptz", true, "DateTimeOffset"},
|
||||
{"timestamp not nullable", "timestamp", false, "DateTime?"},
|
||||
{"timestamptz not nullable", "timestamptz", false, "DateTimeOffset?"},
|
||||
|
||||
// Binary (array type, no ?)
|
||||
{"bytea nullable", "bytea", true, "byte[]"},
|
||||
{"bytea not nullable", "bytea", false, "byte[]"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "Guid"},
|
||||
{"uuid not nullable", "uuid", false, "Guid?"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "string"},
|
||||
|
||||
// Unknown types (object is reference type)
|
||||
{"unknown type nullable", "unknowntype", true, "object"},
|
||||
{"unknown type not nullable", "unknowntype", false, "object"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToCSharp(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToCSharp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsTimeImport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
goType string
|
||||
want bool
|
||||
}{
|
||||
{"time.Time type", "time.Time", true},
|
||||
{"pointer to time.Time", "*time.Time", true},
|
||||
{"int32 type", "int32", false},
|
||||
{"string type", "string", false},
|
||||
{"bool type", "bool", false},
|
||||
{"[]byte type", "[]byte", false},
|
||||
{"interface{}", "interface{}", false},
|
||||
{"empty string", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NeedsTimeImport(tt.goType)
|
||||
if got != tt.want {
|
||||
t.Errorf("NeedsTimeImport(%q) = %v, want %v", tt.goType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoTypeMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int32",
|
||||
"bigint": "int64",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "float64",
|
||||
"bytea": "[]byte",
|
||||
"timestamp": "time.Time",
|
||||
"uuid": "string",
|
||||
"json": "string",
|
||||
}
|
||||
|
||||
for sqlType, expectedGoType := range expectedMappings {
|
||||
if goType, ok := GoTypeMap[sqlType]; !ok {
|
||||
t.Errorf("GoTypeMap missing entry for %q", sqlType)
|
||||
} else if goType != expectedGoType {
|
||||
t.Errorf("GoTypeMap[%q] = %q, want %q", sqlType, goType, expectedGoType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(GoTypeMap) == 0 {
|
||||
t.Error("GoTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeScriptTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "number",
|
||||
"bigint": "number",
|
||||
"text": "string",
|
||||
"boolean": "boolean",
|
||||
"double precision": "number",
|
||||
"bytea": "Buffer",
|
||||
"timestamp": "Date",
|
||||
"uuid": "string",
|
||||
"json": "any",
|
||||
}
|
||||
|
||||
for sqlType, expectedTSType := range expectedMappings {
|
||||
if tsType, ok := TypeScriptTypeMap[sqlType]; !ok {
|
||||
t.Errorf("TypeScriptTypeMap missing entry for %q", sqlType)
|
||||
} else if tsType != expectedTSType {
|
||||
t.Errorf("TypeScriptTypeMap[%q] = %q, want %q", sqlType, tsType, expectedTSType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(TypeScriptTypeMap) == 0 {
|
||||
t.Error("TypeScriptTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPythonTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int",
|
||||
"text": "str",
|
||||
"boolean": "bool",
|
||||
"real": "float",
|
||||
"numeric": "Decimal",
|
||||
"bytea": "bytes",
|
||||
"date": "date",
|
||||
"uuid": "UUID",
|
||||
"json": "dict",
|
||||
}
|
||||
|
||||
for sqlType, expectedPyType := range expectedMappings {
|
||||
if pyType, ok := PythonTypeMap[sqlType]; !ok {
|
||||
t.Errorf("PythonTypeMap missing entry for %q", sqlType)
|
||||
} else if pyType != expectedPyType {
|
||||
t.Errorf("PythonTypeMap[%q] = %q, want %q", sqlType, pyType, expectedPyType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(PythonTypeMap) == 0 {
|
||||
t.Error("PythonTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSharpTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "long",
|
||||
"smallint": "short",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "double",
|
||||
"decimal": "decimal",
|
||||
"bytea": "byte[]",
|
||||
"timestamp": "DateTime",
|
||||
"uuid": "Guid",
|
||||
"json": "string",
|
||||
}
|
||||
|
||||
for sqlType, expectedCSType := range expectedMappings {
|
||||
if csType, ok := CSharpTypeMap[sqlType]; !ok {
|
||||
t.Errorf("CSharpTypeMap missing entry for %q", sqlType)
|
||||
} else if csType != expectedCSType {
|
||||
t.Errorf("CSharpTypeMap[%q] = %q, want %q", sqlType, csType, expectedCSType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(CSharpTypeMap) == 0 {
|
||||
t.Error("CSharpTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToJava(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer nullable", "integer", true, "Integer"},
|
||||
{"integer not nullable", "integer", false, "Integer"},
|
||||
{"bigint nullable", "bigint", true, "Long"},
|
||||
{"smallint nullable", "smallint", true, "Short"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "String"},
|
||||
{"varchar nullable", "varchar", true, "String"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "Boolean"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "Float"},
|
||||
{"double precision nullable", "double precision", true, "Double"},
|
||||
{"numeric nullable", "numeric", true, "BigDecimal"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "Timestamp"},
|
||||
{"date nullable", "date", true, "Date"},
|
||||
{"time nullable", "time", true, "Time"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "byte[]"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "UUID"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "String"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "Object"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToJava(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToJava(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToPhp(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "int"},
|
||||
{"bigint nullable", "bigint", true, "int"},
|
||||
{"smallint nullable", "smallint", true, "int"},
|
||||
|
||||
// Integer types (not nullable)
|
||||
{"integer not nullable", "integer", false, "?int"},
|
||||
{"bigint not nullable", "bigint", false, "?int"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"text not nullable", "text", false, "?string"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "?bool"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "float"},
|
||||
{"double precision nullable", "double precision", true, "float"},
|
||||
{"real not nullable", "real", false, "?float"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "\\DateTime"},
|
||||
{"date nullable", "date", true, "\\DateTime"},
|
||||
{"timestamp not nullable", "timestamp", false, "?\\DateTime"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "string"},
|
||||
{"bytea not nullable", "bytea", false, "?string"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "array"},
|
||||
{"json not nullable", "json", false, "?array"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "string"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "mixed"},
|
||||
{"unknown type not nullable", "unknowntype", false, "mixed"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToPhp(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToPhp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToRust(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "i32"},
|
||||
{"bigint nullable", "bigint", true, "i64"},
|
||||
{"smallint nullable", "smallint", true, "i16"},
|
||||
|
||||
// Integer types (not nullable)
|
||||
{"integer not nullable", "integer", false, "Option<i32>"},
|
||||
{"bigint not nullable", "bigint", false, "Option<i64>"},
|
||||
{"smallint not nullable", "smallint", false, "Option<i16>"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "String"},
|
||||
{"text not nullable", "text", false, "Option<String>"},
|
||||
{"varchar nullable", "varchar", true, "String"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "Option<bool>"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "f32"},
|
||||
{"double precision nullable", "double precision", true, "f64"},
|
||||
{"real not nullable", "real", false, "Option<f32>"},
|
||||
{"double precision not nullable", "double precision", false, "Option<f64>"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "NaiveDateTime"},
|
||||
{"timestamptz nullable", "timestamptz", true, "DateTime<Utc>"},
|
||||
{"date nullable", "date", true, "NaiveDate"},
|
||||
{"time nullable", "time", true, "NaiveTime"},
|
||||
{"timestamp not nullable", "timestamp", false, "Option<NaiveDateTime>"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "Vec<u8>"},
|
||||
{"bytea not nullable", "bytea", false, "Option<Vec<u8>>"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "serde_json::Value"},
|
||||
{"json not nullable", "json", false, "Option<serde_json::Value>"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "String"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "String"},
|
||||
{"unknown type not nullable", "unknowntype", false, "Option<String>"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToRust(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToRust(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJavaTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "Integer",
|
||||
"bigint": "Long",
|
||||
"smallint": "Short",
|
||||
"text": "String",
|
||||
"boolean": "Boolean",
|
||||
"double precision": "Double",
|
||||
"numeric": "BigDecimal",
|
||||
"bytea": "byte[]",
|
||||
"timestamp": "Timestamp",
|
||||
"uuid": "UUID",
|
||||
"date": "Date",
|
||||
}
|
||||
|
||||
for sqlType, expectedJavaType := range expectedMappings {
|
||||
if javaType, ok := JavaTypeMap[sqlType]; !ok {
|
||||
t.Errorf("JavaTypeMap missing entry for %q", sqlType)
|
||||
} else if javaType != expectedJavaType {
|
||||
t.Errorf("JavaTypeMap[%q] = %q, want %q", sqlType, javaType, expectedJavaType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(JavaTypeMap) == 0 {
|
||||
t.Error("JavaTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPHPTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "float",
|
||||
"bytea": "string",
|
||||
"timestamp": "\\DateTime",
|
||||
"uuid": "string",
|
||||
"json": "array",
|
||||
}
|
||||
|
||||
for sqlType, expectedPHPType := range expectedMappings {
|
||||
if phpType, ok := PHPTypeMap[sqlType]; !ok {
|
||||
t.Errorf("PHPTypeMap missing entry for %q", sqlType)
|
||||
} else if phpType != expectedPHPType {
|
||||
t.Errorf("PHPTypeMap[%q] = %q, want %q", sqlType, phpType, expectedPHPType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(PHPTypeMap) == 0 {
|
||||
t.Error("PHPTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRustTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "i32",
|
||||
"bigint": "i64",
|
||||
"smallint": "i16",
|
||||
"text": "String",
|
||||
"boolean": "bool",
|
||||
"double precision": "f64",
|
||||
"real": "f32",
|
||||
"bytea": "Vec<u8>",
|
||||
"timestamp": "NaiveDateTime",
|
||||
"timestamptz": "DateTime<Utc>",
|
||||
"date": "NaiveDate",
|
||||
"json": "serde_json::Value",
|
||||
}
|
||||
|
||||
for sqlType, expectedRustType := range expectedMappings {
|
||||
if rustType, ok := RustTypeMap[sqlType]; !ok {
|
||||
t.Errorf("RustTypeMap missing entry for %q", sqlType)
|
||||
} else if rustType != expectedRustType {
|
||||
t.Errorf("RustTypeMap[%q] = %q, want %q", sqlType, rustType, expectedRustType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(RustTypeMap) == 0 {
|
||||
t.Error("RustTypeMap is empty")
|
||||
}
|
||||
}
|
||||
28
pkg/commontypes/doc.go
Normal file
28
pkg/commontypes/doc.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Package commontypes provides shared type definitions used across multiple packages.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The commontypes package contains common data structures, constants, and type
|
||||
// definitions that are shared between different parts of RelSpec but don't belong
|
||||
// to the core models package.
|
||||
//
|
||||
// # Purpose
|
||||
//
|
||||
// This package helps avoid circular dependencies by providing a common location
|
||||
// for types that are used by multiple packages without creating import cycles.
|
||||
//
|
||||
// # Contents
|
||||
//
|
||||
// Common types may include:
|
||||
// - Shared enums and constants
|
||||
// - Utility type aliases
|
||||
// - Common error types
|
||||
// - Shared configuration structures
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// import "git.warky.dev/wdevs/relspecgo/pkg/commontypes"
|
||||
//
|
||||
// // Use common types
|
||||
// var formatType commontypes.FormatType
|
||||
package commontypes
|
||||
558
pkg/diff/diff_test.go
Normal file
558
pkg/diff/diff_test.go
Normal file
@@ -0,0 +1,558 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestCompareDatabases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source *models.Database
|
||||
target *models.Database
|
||||
want func(*DiffResult) bool
|
||||
}{
|
||||
{
|
||||
name: "identical databases",
|
||||
source: &models.Database{
|
||||
Name: "source",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
target: &models.Database{
|
||||
Name: "target",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
want: func(r *DiffResult) bool {
|
||||
return r.Source == "source" && r.Target == "target" &&
|
||||
len(r.Schemas.Missing) == 0 && len(r.Schemas.Extra) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different schemas",
|
||||
source: &models.Database{
|
||||
Name: "source",
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
target: &models.Database{
|
||||
Name: "target",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
want: func(r *DiffResult) bool {
|
||||
return len(r.Schemas.Missing) == 1 && r.Schemas.Missing[0].Name == "public"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := CompareDatabases(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("CompareDatabases() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareColumns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Column
|
||||
target map[string]*models.Column
|
||||
want func(*ColumnDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical columns",
|
||||
source: map[string]*models.Column{},
|
||||
target: map[string]*models.Column{},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing column",
|
||||
source: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
target: map[string]*models.Column{},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra column",
|
||||
source: map[string]*models.Column{},
|
||||
target: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column type",
|
||||
source: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint"},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "id" &&
|
||||
d.Modified[0].Changes["type"] != nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column nullable",
|
||||
source: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "text", NotNull: true},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "text", NotNull: false},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Changes["not_null"] != nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column length",
|
||||
source: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "varchar", Length: 100},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "varchar", Length: 255},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Changes["length"] != nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareColumns(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareColumns() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareColumnDetails(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source *models.Column
|
||||
target *models.Column
|
||||
want int // number of changes
|
||||
}{
|
||||
{
|
||||
name: "identical columns",
|
||||
source: &models.Column{Name: "id", Type: "integer"},
|
||||
target: &models.Column{Name: "id", Type: "integer"},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "type change",
|
||||
source: &models.Column{Name: "id", Type: "integer"},
|
||||
target: &models.Column{Name: "id", Type: "bigint"},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "length change",
|
||||
source: &models.Column{Name: "name", Type: "varchar", Length: 100},
|
||||
target: &models.Column{Name: "name", Type: "varchar", Length: 255},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "precision change",
|
||||
source: &models.Column{Name: "price", Type: "numeric", Precision: 10},
|
||||
target: &models.Column{Name: "price", Type: "numeric", Precision: 12},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "scale change",
|
||||
source: &models.Column{Name: "price", Type: "numeric", Scale: 2},
|
||||
target: &models.Column{Name: "price", Type: "numeric", Scale: 4},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "not null change",
|
||||
source: &models.Column{Name: "name", Type: "text", NotNull: true},
|
||||
target: &models.Column{Name: "name", Type: "text", NotNull: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "auto increment change",
|
||||
source: &models.Column{Name: "id", Type: "integer", AutoIncrement: true},
|
||||
target: &models.Column{Name: "id", Type: "integer", AutoIncrement: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "primary key change",
|
||||
source: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: true},
|
||||
target: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple changes",
|
||||
source: &models.Column{Name: "id", Type: "integer", NotNull: true, AutoIncrement: true},
|
||||
target: &models.Column{Name: "id", Type: "bigint", NotNull: false, AutoIncrement: false},
|
||||
want: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareColumnDetails(tt.source, tt.target)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("compareColumnDetails() = %d changes, want %d", len(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareIndexes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Index
|
||||
target map[string]*models.Index
|
||||
want func(*IndexDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical indexes",
|
||||
source: map[string]*models.Index{},
|
||||
target: map[string]*models.Index{},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing index",
|
||||
source: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||
},
|
||||
target: map[string]*models.Index{},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra index",
|
||||
source: map[string]*models.Index{},
|
||||
target: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||
},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified index uniqueness",
|
||||
source: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: false},
|
||||
},
|
||||
target: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: true},
|
||||
},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareIndexes(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareIndexes() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareConstraints(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Constraint
|
||||
target map[string]*models.Constraint
|
||||
want func(*ConstraintDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical constraints",
|
||||
source: map[string]*models.Constraint{},
|
||||
target: map[string]*models.Constraint{},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing constraint",
|
||||
source: map[string]*models.Constraint{
|
||||
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||
},
|
||||
target: map[string]*models.Constraint{},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "pk_id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra constraint",
|
||||
source: map[string]*models.Constraint{},
|
||||
target: map[string]*models.Constraint{
|
||||
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||
},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "pk_id"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareConstraints(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareConstraints() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareRelationships(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Relationship
|
||||
target map[string]*models.Relationship
|
||||
want func(*RelationshipDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical relationships",
|
||||
source: map[string]*models.Relationship{},
|
||||
target: map[string]*models.Relationship{},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing relationship",
|
||||
source: map[string]*models.Relationship{
|
||||
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||
},
|
||||
target: map[string]*models.Relationship{},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "fk_user"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra relationship",
|
||||
source: map[string]*models.Relationship{},
|
||||
target: map[string]*models.Relationship{
|
||||
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||
},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "fk_user"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareRelationships(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareRelationships() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source []*models.Table
|
||||
target []*models.Table
|
||||
want func(*TableDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical tables",
|
||||
source: []*models.Table{},
|
||||
target: []*models.Table{},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing table",
|
||||
source: []*models.Table{
|
||||
{Name: "users", Schema: "public"},
|
||||
},
|
||||
target: []*models.Table{},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "users"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra table",
|
||||
source: []*models.Table{},
|
||||
target: []*models.Table{
|
||||
{Name: "users", Schema: "public"},
|
||||
},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "users"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified table",
|
||||
source: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
target: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "users"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareTables(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareTables() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareSchemas(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source []*models.Schema
|
||||
target []*models.Schema
|
||||
want func(*SchemaDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical schemas",
|
||||
source: []*models.Schema{},
|
||||
target: []*models.Schema{},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing schema",
|
||||
source: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
target: []*models.Schema{},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "public"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra schema",
|
||||
source: []*models.Schema{},
|
||||
target: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "public"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareSchemas(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareSchemas() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v interface{}
|
||||
want bool
|
||||
}{
|
||||
{"empty ColumnDiff", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, true},
|
||||
{"ColumnDiff with missing", &ColumnDiff{Missing: []*models.Column{{Name: "id"}}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, false},
|
||||
{"ColumnDiff with extra", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{{Name: "id"}}, Modified: []*ColumnChange{}}, false},
|
||||
{"empty IndexDiff", &IndexDiff{Missing: []*models.Index{}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, true},
|
||||
{"IndexDiff with missing", &IndexDiff{Missing: []*models.Index{{Name: "idx"}}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, false},
|
||||
{"empty TableDiff", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{}, Modified: []*TableChange{}}, true},
|
||||
{"TableDiff with extra", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{{Name: "users"}}, Modified: []*TableChange{}}, false},
|
||||
{"empty ConstraintDiff", &ConstraintDiff{Missing: []*models.Constraint{}, Extra: []*models.Constraint{}, Modified: []*ConstraintChange{}}, true},
|
||||
{"empty RelationshipDiff", &RelationshipDiff{Missing: []*models.Relationship{}, Extra: []*models.Relationship{}, Modified: []*RelationshipChange{}}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isEmpty(tt.v)
|
||||
if got != tt.want {
|
||||
t.Errorf("isEmpty() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeSummary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
want func(*Summary) bool
|
||||
}{
|
||||
{
|
||||
name: "empty diff",
|
||||
result: &DiffResult{
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
want: func(s *Summary) bool {
|
||||
return s.Schemas.Missing == 0 && s.Schemas.Extra == 0 && s.Schemas.Modified == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "schemas with differences",
|
||||
result: &DiffResult{
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{{Name: "schema2"}, {Name: "schema3"}},
|
||||
Modified: []*SchemaChange{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: func(s *Summary) bool {
|
||||
return s.Schemas.Missing == 1 && s.Schemas.Extra == 2 && s.Schemas.Modified == 1
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ComputeSummary(tt.result)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("ComputeSummary() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
43
pkg/diff/doc.go
Normal file
43
pkg/diff/doc.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Package diff provides utilities for comparing database schemas and identifying differences.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The diff package compares two database models at various granularity levels (database,
|
||||
// schema, table, column) and produces detailed reports of differences including:
|
||||
// - Missing items (present in source but not in target)
|
||||
// - Extra items (present in target but not in source)
|
||||
// - Modified items (present in both but with different properties)
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// Compare two databases and format the output:
|
||||
//
|
||||
// result := diff.CompareDatabases(sourceDB, targetDB)
|
||||
// err := diff.FormatDiff(result, diff.OutputFormatText, os.Stdout)
|
||||
//
|
||||
// # Output Formats
|
||||
//
|
||||
// The package supports multiple output formats:
|
||||
// - OutputFormatText: Human-readable text format
|
||||
// - OutputFormatJSON: Structured JSON output
|
||||
// - OutputFormatYAML: Structured YAML output
|
||||
//
|
||||
// # Comparison Scope
|
||||
//
|
||||
// The comparison covers:
|
||||
// - Schemas: Name, description, and contents
|
||||
// - Tables: Name, description, and all sub-elements
|
||||
// - Columns: Type, nullability, defaults, constraints
|
||||
// - Indexes: Columns, uniqueness, type
|
||||
// - Constraints: Type, columns, references
|
||||
// - Relationships: Type, from/to tables and columns
|
||||
// - Views: Definition and columns
|
||||
// - Sequences: Start value, increment, min/max values
|
||||
//
|
||||
// # Use Cases
|
||||
//
|
||||
// - Schema migration planning
|
||||
// - Database synchronization verification
|
||||
// - Change tracking and auditing
|
||||
// - CI/CD pipeline validation
|
||||
package diff
|
||||
440
pkg/diff/formatters_test.go
Normal file
440
pkg/diff/formatters_test.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestFormatDiff(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source_db",
|
||||
Target: "target_db",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
format OutputFormat
|
||||
wantErr bool
|
||||
}{
|
||||
{"summary format", FormatSummary, false},
|
||||
{"json format", FormatJSON, false},
|
||||
{"html format", FormatHTML, false},
|
||||
{"invalid format", OutputFormat("invalid"), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := FormatDiff(result, tt.format, &buf)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FormatDiff() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && buf.Len() == 0 {
|
||||
t.Error("FormatDiff() produced empty output")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
wantStr []string // strings that should appear in output
|
||||
}{
|
||||
{
|
||||
name: "no differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"source", "target", "No differences found"},
|
||||
},
|
||||
{
|
||||
name: "with schema differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{{Name: "schema2"}},
|
||||
Modified: []*SchemaChange{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"Schemas:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||
},
|
||||
{
|
||||
name: "with table differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Missing: []*models.Table{{Name: "users"}},
|
||||
Extra: []*models.Table{{Name: "posts"}},
|
||||
Modified: []*TableChange{
|
||||
{Name: "comments", Schema: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"Tables:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(tt.result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
for _, want := range tt.wantStr {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatJSON(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatJSON(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatJSON() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if output is valid JSON
|
||||
var decoded DiffResult
|
||||
if err := json.Unmarshal(buf.Bytes(), &decoded); err != nil {
|
||||
t.Errorf("formatJSON() produced invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check basic structure
|
||||
if decoded.Source != "source" {
|
||||
t.Errorf("formatJSON() source = %v, want %v", decoded.Source, "source")
|
||||
}
|
||||
if decoded.Target != "target" {
|
||||
t.Errorf("formatJSON() target = %v, want %v", decoded.Target, "target")
|
||||
}
|
||||
if len(decoded.Schemas.Missing) != 1 {
|
||||
t.Errorf("formatJSON() missing schemas = %v, want 1", len(decoded.Schemas.Missing))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatHTML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
wantStr []string // HTML elements/content that should appear
|
||||
}{
|
||||
{
|
||||
name: "basic HTML structure",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"<!DOCTYPE html>",
|
||||
"<title>Database Diff Report</title>",
|
||||
"source",
|
||||
"target",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with schema differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "missing_schema"}},
|
||||
Extra: []*models.Schema{{Name: "extra_schema"}},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"<!DOCTYPE html>",
|
||||
"missing_schema",
|
||||
"extra_schema",
|
||||
"MISSING",
|
||||
"EXTRA",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with table modifications",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: &ColumnDiff{
|
||||
Missing: []*models.Column{{Name: "email", Type: "text"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"public",
|
||||
"users",
|
||||
"email",
|
||||
"text",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := formatHTML(tt.result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatHTML() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
for _, want := range tt.wantStr {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("formatHTML() output doesn't contain %q", want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryWithColumns(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: &ColumnDiff{
|
||||
Missing: []*models.Column{{Name: "email"}},
|
||||
Extra: []*models.Column{{Name: "phone"}, {Name: "address"}},
|
||||
Modified: []*ColumnChange{
|
||||
{Name: "name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
wantStrings := []string{
|
||||
"Columns:",
|
||||
"Missing: 1",
|
||||
"Extra: 2",
|
||||
"Modified: 1",
|
||||
}
|
||||
|
||||
for _, want := range wantStrings {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryWithIndexes(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Indexes: &IndexDiff{
|
||||
Missing: []*models.Index{{Name: "idx_email"}},
|
||||
Extra: []*models.Index{{Name: "idx_phone"}},
|
||||
Modified: []*IndexChange{{Name: "idx_name"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Indexes:") {
|
||||
t.Error("formatSummary() output doesn't contain Indexes section")
|
||||
}
|
||||
if !strings.Contains(output, "Missing: 1") {
|
||||
t.Error("formatSummary() output doesn't contain correct missing count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryWithConstraints(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Constraints: &ConstraintDiff{
|
||||
Missing: []*models.Constraint{{Name: "pk_users", Type: "PRIMARY KEY"}},
|
||||
Extra: []*models.Constraint{{Name: "fk_users_roles", Type: "FOREIGN KEY"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Constraints:") {
|
||||
t.Error("formatSummary() output doesn't contain Constraints section")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatJSONIndentation(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "test"}},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatJSON(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatJSON() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that JSON is indented (has newlines and spaces)
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "\n") {
|
||||
t.Error("formatJSON() should produce indented JSON with newlines")
|
||||
}
|
||||
if !strings.Contains(output, " ") {
|
||||
t.Error("formatJSON() should produce indented JSON with spaces")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputFormatConstants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
format OutputFormat
|
||||
want string
|
||||
}{
|
||||
{"summary constant", FormatSummary, "summary"},
|
||||
{"json constant", FormatJSON, "json"},
|
||||
{"html constant", FormatHTML, "html"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if string(tt.format) != tt.want {
|
||||
t.Errorf("OutputFormat %v = %v, want %v", tt.name, tt.format, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
40
pkg/inspector/doc.go
Normal file
40
pkg/inspector/doc.go
Normal file
@@ -0,0 +1,40 @@
|
||||
// Package inspector provides database introspection capabilities for live databases.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The inspector package contains utilities for connecting to live databases and
|
||||
// extracting their schema information through system catalog queries and metadata
|
||||
// inspection.
|
||||
//
|
||||
// # Features
|
||||
//
|
||||
// - Database connection management
|
||||
// - Schema metadata extraction
|
||||
// - Table structure analysis
|
||||
// - Constraint and index discovery
|
||||
// - Foreign key relationship mapping
|
||||
//
|
||||
// # Supported Databases
|
||||
//
|
||||
// - PostgreSQL (via pgx driver)
|
||||
// - SQLite (via modernc.org/sqlite driver)
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// This package is used internally by database readers (pgsql, sqlite) to perform
|
||||
// live schema introspection:
|
||||
//
|
||||
// inspector := inspector.NewPostgreSQLInspector(connString)
|
||||
// schemas, err := inspector.GetSchemas()
|
||||
// tables, err := inspector.GetTables(schemaName)
|
||||
//
|
||||
// # Architecture
|
||||
//
|
||||
// Each database type has its own inspector implementation that understands the
|
||||
// specific system catalogs and metadata structures of that database system.
|
||||
//
|
||||
// # Security
|
||||
//
|
||||
// Inspectors use read-only operations and never modify database structure.
|
||||
// Connection credentials should be handled securely.
|
||||
package inspector
|
||||
238
pkg/inspector/inspector_test.go
Normal file
238
pkg/inspector/inspector_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewInspector(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
|
||||
if inspector == nil {
|
||||
t.Fatal("NewInspector() returned nil")
|
||||
}
|
||||
|
||||
if inspector.db != db {
|
||||
t.Error("NewInspector() database not set correctly")
|
||||
}
|
||||
|
||||
if inspector.config != config {
|
||||
t.Error("NewInspector() config not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspect(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() returned error: %v", err)
|
||||
}
|
||||
|
||||
if report == nil {
|
||||
t.Fatal("Inspect() returned nil report")
|
||||
}
|
||||
|
||||
if report.Database != db.Name {
|
||||
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
|
||||
}
|
||||
|
||||
if report.Summary.TotalRules != len(config.Rules) {
|
||||
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
|
||||
}
|
||||
|
||||
if len(report.Violations) == 0 {
|
||||
t.Error("Inspect() returned no violations, expected some results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectWithDisabledRules(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Disable all rules
|
||||
for name := range config.Rules {
|
||||
rule := config.Rules[name]
|
||||
rule.Enabled = "off"
|
||||
config.Rules[name] = rule
|
||||
}
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
|
||||
}
|
||||
|
||||
if report.Summary.RulesChecked != 0 {
|
||||
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
|
||||
}
|
||||
|
||||
if report.Summary.RulesSkipped != len(config.Rules) {
|
||||
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectWithEnforcedRules(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Enable only one rule and enforce it
|
||||
for name := range config.Rules {
|
||||
rule := config.Rules[name]
|
||||
rule.Enabled = "off"
|
||||
config.Rules[name] = rule
|
||||
}
|
||||
|
||||
primaryKeyRule := config.Rules["primary_key_naming"]
|
||||
primaryKeyRule.Enabled = "enforce"
|
||||
primaryKeyRule.Pattern = "^id$"
|
||||
config.Rules["primary_key_naming"] = primaryKeyRule
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() returned error: %v", err)
|
||||
}
|
||||
|
||||
if report.Summary.RulesChecked != 1 {
|
||||
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
|
||||
}
|
||||
|
||||
// All results should be at error level for enforced rules
|
||||
for _, violation := range report.Violations {
|
||||
if violation.Level != "error" {
|
||||
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSummary(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
inspector := NewInspector(db, config)
|
||||
|
||||
results := []ValidationResult{
|
||||
{RuleName: "rule1", Passed: true, Level: "error"},
|
||||
{RuleName: "rule2", Passed: false, Level: "error"},
|
||||
{RuleName: "rule3", Passed: false, Level: "warning"},
|
||||
{RuleName: "rule4", Passed: true, Level: "warning"},
|
||||
}
|
||||
|
||||
summary := inspector.generateSummary(results)
|
||||
|
||||
if summary.PassedCount != 2 {
|
||||
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
|
||||
}
|
||||
|
||||
if summary.ErrorCount != 1 {
|
||||
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
|
||||
}
|
||||
|
||||
if summary.WarningCount != 1 {
|
||||
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
report *InspectorReport
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "with errors",
|
||||
report: &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
ErrorCount: 5,
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "without errors",
|
||||
report: &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
ErrorCount: 0,
|
||||
WarningCount: 3,
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.report.HasErrors(); got != tt.want {
|
||||
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
functionName string
|
||||
wantExists bool
|
||||
}{
|
||||
{"primary_key_naming", "primary_key_naming", true},
|
||||
{"primary_key_datatype", "primary_key_datatype", true},
|
||||
{"foreign_key_column_naming", "foreign_key_column_naming", true},
|
||||
{"table_regexpr", "table_regexpr", true},
|
||||
{"column_regexpr", "column_regexpr", true},
|
||||
{"reserved_words", "reserved_words", true},
|
||||
{"have_primary_key", "have_primary_key", true},
|
||||
{"orphaned_foreign_key", "orphaned_foreign_key", true},
|
||||
{"circular_dependency", "circular_dependency", true},
|
||||
{"unknown_function", "unknown_function", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, exists := getValidator(tt.functionName)
|
||||
if exists != tt.wantExists {
|
||||
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateResult(t *testing.T) {
|
||||
result := createResult(
|
||||
"test_rule",
|
||||
true,
|
||||
"Test message",
|
||||
"schema.table.column",
|
||||
map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
},
|
||||
)
|
||||
|
||||
if result.RuleName != "test_rule" {
|
||||
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
|
||||
}
|
||||
|
||||
if !result.Passed {
|
||||
t.Error("createResult() Passed = false, want true")
|
||||
}
|
||||
|
||||
if result.Message != "Test message" {
|
||||
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
|
||||
}
|
||||
|
||||
if result.Location != "schema.table.column" {
|
||||
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
|
||||
}
|
||||
|
||||
if len(result.Context) != 2 {
|
||||
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
|
||||
}
|
||||
}
|
||||
366
pkg/inspector/report_test.go
Normal file
366
pkg/inspector/report_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func createTestReport() *InspectorReport {
|
||||
return &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
TotalRules: 10,
|
||||
RulesChecked: 8,
|
||||
RulesSkipped: 2,
|
||||
ErrorCount: 3,
|
||||
WarningCount: 5,
|
||||
PassedCount: 12,
|
||||
},
|
||||
Violations: []ValidationResult{
|
||||
{
|
||||
RuleName: "primary_key_naming",
|
||||
Level: "error",
|
||||
Message: "Primary key should start with 'id_'",
|
||||
Location: "public.users.user_id",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "users",
|
||||
"column": "user_id",
|
||||
"pattern": "^id_",
|
||||
},
|
||||
},
|
||||
{
|
||||
RuleName: "table_name_length",
|
||||
Level: "warning",
|
||||
Message: "Table name too long",
|
||||
Location: "public.very_long_table_name_that_exceeds_limits",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "very_long_table_name_that_exceeds_limits",
|
||||
"length": 44,
|
||||
"max_length": 32,
|
||||
},
|
||||
},
|
||||
},
|
||||
GeneratedAt: time.Now(),
|
||||
Database: "testdb",
|
||||
SourceFormat: "postgresql",
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMarkdownFormatter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
if formatter == nil {
|
||||
t.Fatal("NewMarkdownFormatter() returned nil")
|
||||
}
|
||||
|
||||
// Buffer is not a terminal, so colors should be disabled
|
||||
if formatter.UseColors {
|
||||
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewJSONFormatter(t *testing.T) {
|
||||
formatter := NewJSONFormatter()
|
||||
|
||||
if formatter == nil {
|
||||
t.Fatal("NewJSONFormatter() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_Format(t *testing.T) {
|
||||
report := createTestReport()
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Check that output contains expected sections
|
||||
if !strings.Contains(output, "# RelSpec Inspector Report") {
|
||||
t.Error("Markdown output missing header")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Database:") {
|
||||
t.Error("Markdown output missing database field")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "testdb") {
|
||||
t.Error("Markdown output missing database name")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Summary") {
|
||||
t.Error("Markdown output missing summary section")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Rules Checked: 8") {
|
||||
t.Error("Markdown output missing rules checked count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Errors: 3") {
|
||||
t.Error("Markdown output missing error count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Warnings: 5") {
|
||||
t.Error("Markdown output missing warning count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Violations") {
|
||||
t.Error("Markdown output missing violations section")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "primary_key_naming") {
|
||||
t.Error("Markdown output missing rule name")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "public.users.user_id") {
|
||||
t.Error("Markdown output missing location")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
|
||||
report := &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
TotalRules: 10,
|
||||
RulesChecked: 10,
|
||||
RulesSkipped: 0,
|
||||
ErrorCount: 0,
|
||||
WarningCount: 0,
|
||||
PassedCount: 50,
|
||||
},
|
||||
Violations: []ValidationResult{},
|
||||
GeneratedAt: time.Now(),
|
||||
Database: "testdb",
|
||||
SourceFormat: "postgresql",
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "No violations found") {
|
||||
t.Error("Markdown output should indicate no violations")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFormatter_Format(t *testing.T) {
|
||||
report := createTestReport()
|
||||
formatter := NewJSONFormatter()
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's valid JSON
|
||||
var decoded InspectorReport
|
||||
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
|
||||
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check key fields
|
||||
if decoded.Database != "testdb" {
|
||||
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
|
||||
}
|
||||
|
||||
if decoded.Summary.ErrorCount != 3 {
|
||||
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
|
||||
}
|
||||
|
||||
if len(decoded.Violations) != 2 {
|
||||
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
header := formatter.formatHeader("Test Header")
|
||||
|
||||
if !strings.Contains(header, "# Test Header") {
|
||||
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatBold(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
useColors bool
|
||||
text string
|
||||
wantContains string
|
||||
}{
|
||||
{
|
||||
name: "without colors",
|
||||
useColors: false,
|
||||
text: "Bold Text",
|
||||
wantContains: "**Bold Text**",
|
||||
},
|
||||
{
|
||||
name: "with colors",
|
||||
useColors: true,
|
||||
text: "Bold Text",
|
||||
wantContains: "Bold Text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||
result := formatter.formatBold(tt.text)
|
||||
|
||||
if !strings.Contains(result, tt.wantContains) {
|
||||
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_Colorize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
useColors bool
|
||||
text string
|
||||
color string
|
||||
wantColor bool
|
||||
}{
|
||||
{
|
||||
name: "without colors",
|
||||
useColors: false,
|
||||
text: "Test",
|
||||
color: colorRed,
|
||||
wantColor: false,
|
||||
},
|
||||
{
|
||||
name: "with colors",
|
||||
useColors: true,
|
||||
text: "Test",
|
||||
color: colorRed,
|
||||
wantColor: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||
result := formatter.colorize(tt.text, tt.color)
|
||||
|
||||
hasColor := strings.Contains(result, tt.color)
|
||||
if hasColor != tt.wantColor {
|
||||
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, tt.text) {
|
||||
t.Errorf("colorize() doesn't contain original text %q", tt.text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatContext(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: false}
|
||||
|
||||
context := map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "users",
|
||||
"column": "id",
|
||||
"pattern": "^id_",
|
||||
"max_length": 64,
|
||||
}
|
||||
|
||||
result := formatter.formatContext(context)
|
||||
|
||||
// Should not include schema, table, column (they're in location)
|
||||
if strings.Contains(result, "schema") {
|
||||
t.Error("formatContext() should skip schema field")
|
||||
}
|
||||
|
||||
if strings.Contains(result, "table=") {
|
||||
t.Error("formatContext() should skip table field")
|
||||
}
|
||||
|
||||
if strings.Contains(result, "column=") {
|
||||
t.Error("formatContext() should skip column field")
|
||||
}
|
||||
|
||||
// Should include other fields
|
||||
if !strings.Contains(result, "pattern") {
|
||||
t.Error("formatContext() should include pattern field")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "max_length") {
|
||||
t.Error("formatContext() should include max_length field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: false}
|
||||
|
||||
violation := ValidationResult{
|
||||
RuleName: "test_rule",
|
||||
Level: "error",
|
||||
Message: "Test violation message",
|
||||
Location: "public.users.id",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"pattern": "^id_",
|
||||
},
|
||||
}
|
||||
|
||||
result := formatter.formatViolation(violation, colorRed)
|
||||
|
||||
if !strings.Contains(result, "test_rule") {
|
||||
t.Error("formatViolation() should include rule name")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Test violation message") {
|
||||
t.Error("formatViolation() should include message")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "public.users.id") {
|
||||
t.Error("formatViolation() should include location")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Location:") {
|
||||
t.Error("formatViolation() should include Location label")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Message:") {
|
||||
t.Error("formatViolation() should include Message label")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportFormatConstants(t *testing.T) {
|
||||
// Test that color constants are defined
|
||||
if colorReset == "" {
|
||||
t.Error("colorReset is not defined")
|
||||
}
|
||||
|
||||
if colorRed == "" {
|
||||
t.Error("colorRed is not defined")
|
||||
}
|
||||
|
||||
if colorYellow == "" {
|
||||
t.Error("colorYellow is not defined")
|
||||
}
|
||||
|
||||
if colorGreen == "" {
|
||||
t.Error("colorGreen is not defined")
|
||||
}
|
||||
|
||||
if colorBold == "" {
|
||||
t.Error("colorBold is not defined")
|
||||
}
|
||||
}
|
||||
249
pkg/inspector/rules_test.go
Normal file
249
pkg/inspector/rules_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDefaultConfig(t *testing.T) {
|
||||
config := GetDefaultConfig()
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("GetDefaultConfig() returned nil")
|
||||
}
|
||||
|
||||
if config.Version != "1.0" {
|
||||
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
|
||||
}
|
||||
|
||||
if len(config.Rules) == 0 {
|
||||
t.Error("GetDefaultConfig() returned no rules")
|
||||
}
|
||||
|
||||
// Check that all expected rules are present
|
||||
expectedRules := []string{
|
||||
"primary_key_naming",
|
||||
"primary_key_datatype",
|
||||
"primary_key_auto_increment",
|
||||
"foreign_key_column_naming",
|
||||
"foreign_key_constraint_naming",
|
||||
"foreign_key_index",
|
||||
"table_naming_case",
|
||||
"column_naming_case",
|
||||
"table_name_length",
|
||||
"column_name_length",
|
||||
"reserved_keywords",
|
||||
"missing_primary_key",
|
||||
"orphaned_foreign_key",
|
||||
"circular_dependency",
|
||||
}
|
||||
|
||||
for _, ruleName := range expectedRules {
|
||||
if _, exists := config.Rules[ruleName]; !exists {
|
||||
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_NonExistentFile(t *testing.T) {
|
||||
// Try to load a non-existent file
|
||||
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
|
||||
}
|
||||
|
||||
// Should return default config
|
||||
if config == nil {
|
||||
t.Fatal("LoadConfig() returned nil config for non-existent file")
|
||||
}
|
||||
|
||||
if len(config.Rules) == 0 {
|
||||
t.Error("LoadConfig() returned config with no rules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ValidFile(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "test-config.yaml")
|
||||
|
||||
configContent := `version: "1.0"
|
||||
rules:
|
||||
primary_key_naming:
|
||||
enabled: "enforce"
|
||||
function: "primary_key_naming"
|
||||
pattern: "^pk_"
|
||||
message: "Primary keys must start with pk_"
|
||||
table_name_length:
|
||||
enabled: "warn"
|
||||
function: "table_name_length"
|
||||
max_length: 50
|
||||
message: "Table name too long"
|
||||
`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
config, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() returned error: %v", err)
|
||||
}
|
||||
|
||||
if config.Version != "1.0" {
|
||||
t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
|
||||
}
|
||||
|
||||
if len(config.Rules) != 2 {
|
||||
t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
|
||||
}
|
||||
|
||||
// Check primary_key_naming rule
|
||||
pkRule, exists := config.Rules["primary_key_naming"]
|
||||
if !exists {
|
||||
t.Fatal("LoadConfig() missing primary_key_naming rule")
|
||||
}
|
||||
|
||||
if pkRule.Enabled != "enforce" {
|
||||
t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
|
||||
}
|
||||
|
||||
if pkRule.Pattern != "^pk_" {
|
||||
t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
|
||||
}
|
||||
|
||||
// Check table_name_length rule
|
||||
lengthRule, exists := config.Rules["table_name_length"]
|
||||
if !exists {
|
||||
t.Fatal("LoadConfig() missing table_name_length rule")
|
||||
}
|
||||
|
||||
if lengthRule.MaxLength != 50 {
|
||||
t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_InvalidYAML(t *testing.T) {
|
||||
// Create a temporary invalid config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
|
||||
|
||||
invalidContent := `invalid: yaml: content: {[}]`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = LoadConfig(configPath)
|
||||
if err == nil {
|
||||
t.Error("LoadConfig() with invalid YAML did not return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleIsEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "enforce is enabled",
|
||||
rule: Rule{Enabled: "enforce"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "warn is enabled",
|
||||
rule: Rule{Enabled: "warn"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "off is not enabled",
|
||||
rule: Rule{Enabled: "off"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty is not enabled",
|
||||
rule: Rule{Enabled: ""},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.rule.IsEnabled(); got != tt.want {
|
||||
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleIsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "enforce is enforced",
|
||||
rule: Rule{Enabled: "enforce"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "warn is not enforced",
|
||||
rule: Rule{Enabled: "warn"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "off is not enforced",
|
||||
rule: Rule{Enabled: "off"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.rule.IsEnforced(); got != tt.want {
|
||||
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfigRuleSettings(t *testing.T) {
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Test specific rule settings
|
||||
pkNamingRule := config.Rules["primary_key_naming"]
|
||||
if pkNamingRule.Function != "primary_key_naming" {
|
||||
t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
|
||||
}
|
||||
|
||||
if pkNamingRule.Pattern != "^id_" {
|
||||
t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
|
||||
}
|
||||
|
||||
// Test datatype rule
|
||||
pkDatatypeRule := config.Rules["primary_key_datatype"]
|
||||
if len(pkDatatypeRule.AllowedTypes) == 0 {
|
||||
t.Error("primary_key_datatype has no allowed types")
|
||||
}
|
||||
|
||||
// Test length rule
|
||||
tableLengthRule := config.Rules["table_name_length"]
|
||||
if tableLengthRule.MaxLength != 64 {
|
||||
t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
|
||||
}
|
||||
|
||||
// Test reserved keywords rule
|
||||
reservedRule := config.Rules["reserved_keywords"]
|
||||
if !reservedRule.CheckTables {
|
||||
t.Error("reserved_keywords.CheckTables should be true")
|
||||
}
|
||||
if !reservedRule.CheckColumns {
|
||||
t.Error("reserved_keywords.CheckColumns should be true")
|
||||
}
|
||||
}
|
||||
837
pkg/inspector/validators_test.go
Normal file
837
pkg/inspector/validators_test.go
Normal file
@@ -0,0 +1,837 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// Helper function to create test database
|
||||
func createTestDatabase() *models.Database {
|
||||
return &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigserial",
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
},
|
||||
"username": {
|
||||
Name: "username",
|
||||
Type: "varchar(50)",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
"rid_organization": {
|
||||
Name: "rid_organization",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_users_organization": {
|
||||
Name: "fk_users_organization",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_organization"},
|
||||
ReferencedTable: "organizations",
|
||||
ReferencedSchema: "public",
|
||||
ReferencedColumns: []string{"id"},
|
||||
},
|
||||
},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_rid_organization": {
|
||||
Name: "idx_rid_organization",
|
||||
Columns: []string{"rid_organization"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "organizations",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigserial",
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
},
|
||||
"name": {
|
||||
Name: "name",
|
||||
Type: "varchar(100)",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern id",
|
||||
rule: Rule{
|
||||
Pattern: "^id$",
|
||||
Message: "Primary key should be 'id'",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern id_",
|
||||
rule: Rule{
|
||||
Pattern: "^id_",
|
||||
Message: "Primary key should start with 'id_'",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyDatatype(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "allowed type bigserial",
|
||||
rule: Rule{
|
||||
AllowedTypes: []string{"bigserial", "bigint", "int"},
|
||||
Message: "Primary key should use integer types",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "disallowed type",
|
||||
rule: Rule{
|
||||
AllowedTypes: []string{"uuid"},
|
||||
Message: "Primary key should use UUID",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "require auto increment",
|
||||
rule: Rule{
|
||||
RequireAutoIncrement: true,
|
||||
Message: "Primary key should have auto-increment",
|
||||
},
|
||||
wantLen: 0, // No violations - all PKs have auto-increment
|
||||
},
|
||||
{
|
||||
name: "disallow auto increment",
|
||||
rule: Rule{
|
||||
RequireAutoIncrement: false,
|
||||
Message: "Primary key should not have auto-increment",
|
||||
},
|
||||
wantLen: 2, // 2 violations - both PKs have auto-increment
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyColumnNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern rid_",
|
||||
rule: Rule{
|
||||
Pattern: "^rid_",
|
||||
Message: "Foreign key columns should start with 'rid_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern fk_",
|
||||
rule: Rule{
|
||||
Pattern: "^fk_",
|
||||
Message: "Foreign key columns should start with 'fk_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern fk_",
|
||||
rule: Rule{
|
||||
Pattern: "^fk_",
|
||||
Message: "Foreign key constraints should start with 'fk_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern FK_",
|
||||
rule: Rule{
|
||||
Pattern: "^FK_",
|
||||
Message: "Foreign key constraints should start with 'FK_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyIndex(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "require index with index present",
|
||||
rule: Rule{
|
||||
RequireIndex: true,
|
||||
Message: "Foreign key columns should have indexes",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "no requirement",
|
||||
rule: Rule{
|
||||
RequireIndex: false,
|
||||
Message: "Foreign key index check disabled",
|
||||
},
|
||||
wantLen: 0,
|
||||
wantPass: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyIndex(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTableNamingCase(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "lowercase snake_case pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-z0-9_]*$",
|
||||
Case: "lowercase",
|
||||
Message: "Table names should be lowercase snake_case",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[A-Z][A-Z0-9_]*$",
|
||||
Case: "uppercase",
|
||||
Message: "Table names should be uppercase",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateTableNamingCase(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnNamingCase(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "lowercase snake_case pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-z0-9_]*$",
|
||||
Case: "lowercase",
|
||||
Message: "Column names should be lowercase snake_case",
|
||||
},
|
||||
wantLen: 5, // 5 total columns across both tables
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "camelCase pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-zA-Z0-9]*$",
|
||||
Case: "camelCase",
|
||||
Message: "Column names should be camelCase",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: false, // rid_organization has underscore
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateColumnNamingCase(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTableNameLength(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "max length 64",
|
||||
rule: Rule{
|
||||
MaxLength: 64,
|
||||
Message: "Table name too long",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "max length 5",
|
||||
rule: Rule{
|
||||
MaxLength: 5,
|
||||
Message: "Table name too long",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateTableNameLength(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnNameLength(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "max length 64",
|
||||
rule: Rule{
|
||||
MaxLength: 64,
|
||||
Message: "Column name too long",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "max length 5",
|
||||
rule: Rule{
|
||||
MaxLength: 5,
|
||||
Message: "Column name too long",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: false, // Some columns exceed 5 chars
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateColumnNameLength(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReservedKeywords(t *testing.T) {
|
||||
// Create a database with reserved keywords
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "user", // "user" is a reserved keyword
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
"select": { // "select" is a reserved keyword
|
||||
Name: "select",
|
||||
Type: "varchar(50)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
checkPasses bool
|
||||
}{
|
||||
{
|
||||
name: "check tables only",
|
||||
rule: Rule{
|
||||
CheckTables: true,
|
||||
CheckColumns: false,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 1, // "user" table
|
||||
checkPasses: false,
|
||||
},
|
||||
{
|
||||
name: "check columns only",
|
||||
rule: Rule{
|
||||
CheckTables: false,
|
||||
CheckColumns: true,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 2, // "id", "select" columns (id passes, select fails)
|
||||
checkPasses: false,
|
||||
},
|
||||
{
|
||||
name: "check both",
|
||||
rule: Rule{
|
||||
CheckTables: true,
|
||||
CheckColumns: true,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 3, // "user" table + "id", "select" columns
|
||||
checkPasses: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateReservedKeywords(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMissingPrimaryKey(t *testing.T) {
|
||||
// Create database with and without primary keys
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "with_pk",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "without_pk",
|
||||
Columns: map[string]*models.Column{
|
||||
"name": {
|
||||
Name: "name",
|
||||
Type: "varchar(50)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Table missing primary key",
|
||||
}
|
||||
|
||||
results := validateMissingPrimaryKey(db, rule, "test_rule")
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
|
||||
}
|
||||
|
||||
// First result should pass (with_pk has PK)
|
||||
if results[0].Passed != true {
|
||||
t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
|
||||
}
|
||||
|
||||
// Second result should fail (without_pk missing PK)
|
||||
if results[1].Passed != false {
|
||||
t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOrphanedForeignKey(t *testing.T) {
|
||||
// Create database with orphaned FK
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_nonexistent": {
|
||||
Name: "fk_nonexistent",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_organization"},
|
||||
ReferencedTable: "nonexistent_table",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Foreign key references non-existent table",
|
||||
}
|
||||
|
||||
results := validateOrphanedForeignKey(db, rule, "test_rule")
|
||||
|
||||
if len(results) != 1 {
|
||||
t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
|
||||
}
|
||||
|
||||
if results[0].Passed != false {
|
||||
t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCircularDependency(t *testing.T) {
|
||||
// Create database with circular dependency
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "table_a",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_to_b": {
|
||||
Name: "fk_to_b",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
ReferencedTable: "table_b",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "table_b",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_to_a": {
|
||||
Name: "fk_to_a",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
ReferencedTable: "table_a",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Circular dependency detected",
|
||||
}
|
||||
|
||||
results := validateCircularDependency(db, rule, "test_rule")
|
||||
|
||||
// Should detect circular dependency in both tables
|
||||
if len(results) == 0 {
|
||||
t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
if result.Passed {
|
||||
t.Error("validateCircularDependency() passed=true, want false for circular dependency")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeDataType(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"varchar(50)", "varchar"},
|
||||
{"decimal(10,2)", "decimal"},
|
||||
{"int", "int"},
|
||||
{"BIGINT", "bigint"},
|
||||
{"VARCHAR(255)", "varchar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := normalizeDataType(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
value string
|
||||
expected bool
|
||||
}{
|
||||
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
|
||||
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
|
||||
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
|
||||
{"empty slice", []string{}, "foo", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.value)
|
||||
if result != tt.expected {
|
||||
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCycle(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
graph map[string][]string
|
||||
node string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "simple cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"B"},
|
||||
"B": {"C"},
|
||||
"C": {"A"},
|
||||
},
|
||||
node: "A",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"B"},
|
||||
"B": {"C"},
|
||||
"C": {},
|
||||
},
|
||||
node: "A",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "self cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"A"},
|
||||
},
|
||||
node: "A",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
visited := make(map[string]bool)
|
||||
recStack := make(map[string]bool)
|
||||
result := hasCycle(tt.node, tt.graph, visited, recStack)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatLocation(t *testing.T) {
|
||||
tests := []struct {
|
||||
schema string
|
||||
table string
|
||||
column string
|
||||
expected string
|
||||
}{
|
||||
{"public", "users", "id", "public.users.id"},
|
||||
{"public", "users", "", "public.users"},
|
||||
{"public", "", "", "public"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
result := formatLocation(tt.schema, tt.table, tt.column)
|
||||
if result != tt.expected {
|
||||
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
|
||||
tt.schema, tt.table, tt.column, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,14 +12,16 @@ import (
|
||||
|
||||
// MergeResult represents the result of a merge operation
|
||||
type MergeResult struct {
|
||||
SchemasAdded int
|
||||
TablesAdded int
|
||||
ColumnsAdded int
|
||||
RelationsAdded int
|
||||
DomainsAdded int
|
||||
EnumsAdded int
|
||||
ViewsAdded int
|
||||
SequencesAdded int
|
||||
SchemasAdded int
|
||||
TablesAdded int
|
||||
ColumnsAdded int
|
||||
ConstraintsAdded int
|
||||
IndexesAdded int
|
||||
RelationsAdded int
|
||||
DomainsAdded int
|
||||
EnumsAdded int
|
||||
ViewsAdded int
|
||||
SequencesAdded int
|
||||
}
|
||||
|
||||
// MergeOptions contains options for merge operations
|
||||
@@ -120,8 +122,10 @@ func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema,
|
||||
}
|
||||
|
||||
if tgtTable, exists := existingTables[tableName]; exists {
|
||||
// Table exists, merge its columns
|
||||
// Table exists, merge its columns, constraints, and indexes
|
||||
r.mergeColumns(tgtTable, srcTable)
|
||||
r.mergeConstraints(tgtTable, srcTable)
|
||||
r.mergeIndexes(tgtTable, srcTable)
|
||||
} else {
|
||||
// Table doesn't exist, add it
|
||||
newTable := cloneTable(srcTable)
|
||||
@@ -151,6 +155,52 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Table) {
|
||||
// Initialize constraints map if nil
|
||||
if table.Constraints == nil {
|
||||
table.Constraints = make(map[string]*models.Constraint)
|
||||
}
|
||||
|
||||
// Create map of existing constraints
|
||||
existingConstraints := make(map[string]*models.Constraint)
|
||||
for constName := range table.Constraints {
|
||||
existingConstraints[constName] = table.Constraints[constName]
|
||||
}
|
||||
|
||||
// Merge constraints
|
||||
for constName, srcConst := range srcTable.Constraints {
|
||||
if _, exists := existingConstraints[constName]; !exists {
|
||||
// Constraint doesn't exist, add it
|
||||
newConst := cloneConstraint(srcConst)
|
||||
table.Constraints[constName] = newConst
|
||||
r.ConstraintsAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table) {
|
||||
// Initialize indexes map if nil
|
||||
if table.Indexes == nil {
|
||||
table.Indexes = make(map[string]*models.Index)
|
||||
}
|
||||
|
||||
// Create map of existing indexes
|
||||
existingIndexes := make(map[string]*models.Index)
|
||||
for idxName := range table.Indexes {
|
||||
existingIndexes[idxName] = table.Indexes[idxName]
|
||||
}
|
||||
|
||||
// Merge indexes
|
||||
for idxName, srcIdx := range srcTable.Indexes {
|
||||
if _, exists := existingIndexes[idxName]; !exists {
|
||||
// Index doesn't exist, add it
|
||||
newIdx := cloneIndex(srcIdx)
|
||||
table.Indexes[idxName] = newIdx
|
||||
r.IndexesAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
|
||||
// Create map of existing views
|
||||
existingViews := make(map[string]*models.View)
|
||||
@@ -552,6 +602,8 @@ func GetMergeSummary(result *MergeResult) string {
|
||||
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
|
||||
fmt.Sprintf("Tables added: %d", result.TablesAdded),
|
||||
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
|
||||
fmt.Sprintf("Constraints added: %d", result.ConstraintsAdded),
|
||||
fmt.Sprintf("Indexes added: %d", result.IndexesAdded),
|
||||
fmt.Sprintf("Views added: %d", result.ViewsAdded),
|
||||
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
|
||||
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
||||
@@ -560,6 +612,7 @@ func GetMergeSummary(result *MergeResult) string {
|
||||
}
|
||||
|
||||
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
||||
result.ConstraintsAdded + result.IndexesAdded +
|
||||
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
|
||||
result.RelationsAdded + result.DomainsAdded
|
||||
|
||||
|
||||
617
pkg/merge/merge_test.go
Normal file
617
pkg/merge/merge_test.go
Normal file
@@ -0,0 +1,617 @@
|
||||
package merge
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestMergeDatabases_NilInputs(t *testing.T) {
|
||||
result := MergeDatabases(nil, nil, nil)
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
if result.SchemasAdded != 0 {
|
||||
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeDatabases_NewSchema(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "auth"},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.SchemasAdded != 1 {
|
||||
t.Errorf("Expected 1 schema added, got %d", result.SchemasAdded)
|
||||
}
|
||||
if len(target.Schemas) != 2 {
|
||||
t.Errorf("Expected 2 schemas in target, got %d", len(target.Schemas))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeDatabases_ExistingSchema(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.SchemasAdded != 0 {
|
||||
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
|
||||
}
|
||||
if len(target.Schemas) != 1 {
|
||||
t.Errorf("Expected 1 schema in target, got %d", len(target.Schemas))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeTables_NewTable(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "posts",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.TablesAdded != 1 {
|
||||
t.Errorf("Expected 1 table added, got %d", result.TablesAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables) != 2 {
|
||||
t.Errorf("Expected 2 tables in target schema, got %d", len(target.Schemas[0].Tables))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeColumns_NewColumn(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "int"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"email": {Name: "email", Type: "varchar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.ColumnsAdded != 1 {
|
||||
t.Errorf("Expected 1 column added, got %d", result.ColumnsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns in target table, got %d", len(target.Schemas[0].Tables[0].Columns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeConstraints_NewConstraint(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Constraints: map[string]*models.Constraint{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"ukey_users_email": {
|
||||
Type: models.UniqueConstraint,
|
||||
Columns: []string{"email"},
|
||||
Name: "ukey_users_email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.ConstraintsAdded != 1 {
|
||||
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
|
||||
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeConstraints_NilConstraintsMap(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Constraints: nil, // Nil map
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"ukey_users_email": {
|
||||
Type: models.UniqueConstraint,
|
||||
Columns: []string{"email"},
|
||||
Name: "ukey_users_email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.ConstraintsAdded != 1 {
|
||||
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
|
||||
}
|
||||
if target.Schemas[0].Tables[0].Constraints == nil {
|
||||
t.Error("Expected constraints map to be initialized")
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
|
||||
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeIndexes_NewIndex(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Indexes: map[string]*models.Index{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_users_email": {
|
||||
Name: "idx_users_email",
|
||||
Columns: []string{"email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.IndexesAdded != 1 {
|
||||
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
|
||||
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeIndexes_NilIndexesMap(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Indexes: nil, // Nil map
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_users_email": {
|
||||
Name: "idx_users_email",
|
||||
Columns: []string{"email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.IndexesAdded != 1 {
|
||||
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||
}
|
||||
if target.Schemas[0].Tables[0].Indexes == nil {
|
||||
t.Error("Expected indexes map to be initialized")
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
|
||||
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeOptions_SkipTableNames(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "migrations",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
opts := &MergeOptions{
|
||||
SkipTableNames: map[string]bool{
|
||||
"migrations": true,
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, opts)
|
||||
if result.TablesAdded != 0 {
|
||||
t.Errorf("Expected 0 tables added (skipped), got %d", result.TablesAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables) != 1 {
|
||||
t.Errorf("Expected 1 table in target schema, got %d", len(target.Schemas[0].Tables))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeViews_NewView(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Views: []*models.View{},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Views: []*models.View{
|
||||
{
|
||||
Name: "user_summary",
|
||||
Schema: "public",
|
||||
Definition: "SELECT * FROM users",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.ViewsAdded != 1 {
|
||||
t.Errorf("Expected 1 view added, got %d", result.ViewsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Views) != 1 {
|
||||
t.Errorf("Expected 1 view in target schema, got %d", len(target.Schemas[0].Views))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeEnums_NewEnum(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Enums: []*models.Enum{},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Enums: []*models.Enum{
|
||||
{
|
||||
Name: "user_role",
|
||||
Schema: "public",
|
||||
Values: []string{"admin", "user"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.EnumsAdded != 1 {
|
||||
t.Errorf("Expected 1 enum added, got %d", result.EnumsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Enums) != 1 {
|
||||
t.Errorf("Expected 1 enum in target schema, got %d", len(target.Schemas[0].Enums))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeDomains_NewDomain(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Domains: []*models.Domain{},
|
||||
}
|
||||
source := &models.Database{
|
||||
Domains: []*models.Domain{
|
||||
{
|
||||
Name: "auth",
|
||||
Description: "Authentication domain",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.DomainsAdded != 1 {
|
||||
t.Errorf("Expected 1 domain added, got %d", result.DomainsAdded)
|
||||
}
|
||||
if len(target.Domains) != 1 {
|
||||
t.Errorf("Expected 1 domain in target, got %d", len(target.Domains))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeRelations_NewRelation(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Relations: []*models.Relationship{},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Relations: []*models.Relationship{
|
||||
{
|
||||
Name: "fk_posts_user",
|
||||
Type: models.OneToMany,
|
||||
FromTable: "posts",
|
||||
FromColumns: []string{"user_id"},
|
||||
ToTable: "users",
|
||||
ToColumns: []string{"id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.RelationsAdded != 1 {
|
||||
t.Errorf("Expected 1 relation added, got %d", result.RelationsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Relations) != 1 {
|
||||
t.Errorf("Expected 1 relation in target schema, got %d", len(target.Schemas[0].Relations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMergeSummary(t *testing.T) {
|
||||
result := &MergeResult{
|
||||
SchemasAdded: 1,
|
||||
TablesAdded: 2,
|
||||
ColumnsAdded: 5,
|
||||
ConstraintsAdded: 3,
|
||||
IndexesAdded: 2,
|
||||
ViewsAdded: 1,
|
||||
}
|
||||
|
||||
summary := GetMergeSummary(result)
|
||||
if summary == "" {
|
||||
t.Error("Expected non-empty summary")
|
||||
}
|
||||
if len(summary) < 50 {
|
||||
t.Errorf("Summary seems too short: %s", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMergeSummary_Nil(t *testing.T) {
|
||||
summary := GetMergeSummary(nil)
|
||||
if summary == "" {
|
||||
t.Error("Expected non-empty summary for nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexMerge(t *testing.T) {
|
||||
// Target with existing structure
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "int"},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{},
|
||||
Indexes: map[string]*models.Index{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Source with new columns, constraints, and indexes
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"email": {Name: "email", Type: "varchar"},
|
||||
"guid": {Name: "guid", Type: "uuid"},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"ukey_users_email": {
|
||||
Type: models.UniqueConstraint,
|
||||
Columns: []string{"email"},
|
||||
Name: "ukey_users_email",
|
||||
},
|
||||
"ukey_users_guid": {
|
||||
Type: models.UniqueConstraint,
|
||||
Columns: []string{"guid"},
|
||||
Name: "ukey_users_guid",
|
||||
},
|
||||
},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_users_email": {
|
||||
Name: "idx_users_email",
|
||||
Columns: []string{"email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
|
||||
// Verify counts
|
||||
if result.ColumnsAdded != 2 {
|
||||
t.Errorf("Expected 2 columns added, got %d", result.ColumnsAdded)
|
||||
}
|
||||
if result.ConstraintsAdded != 2 {
|
||||
t.Errorf("Expected 2 constraints added, got %d", result.ConstraintsAdded)
|
||||
}
|
||||
if result.IndexesAdded != 1 {
|
||||
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||
}
|
||||
|
||||
// Verify target has merged data
|
||||
table := target.Schemas[0].Tables[0]
|
||||
if len(table.Columns) != 3 {
|
||||
t.Errorf("Expected 3 columns in merged table, got %d", len(table.Columns))
|
||||
}
|
||||
if len(table.Constraints) != 2 {
|
||||
t.Errorf("Expected 2 constraints in merged table, got %d", len(table.Constraints))
|
||||
}
|
||||
if len(table.Indexes) != 1 {
|
||||
t.Errorf("Expected 1 index in merged table, got %d", len(table.Indexes))
|
||||
}
|
||||
|
||||
// Verify specific constraint
|
||||
if _, exists := table.Constraints["ukey_users_guid"]; !exists {
|
||||
t.Error("Expected ukey_users_guid constraint to exist")
|
||||
}
|
||||
}
|
||||
99
pkg/mssql/README.md
Normal file
99
pkg/mssql/README.md
Normal file
@@ -0,0 +1,99 @@
|
||||
# MSSQL Package
|
||||
|
||||
Provides utilities for working with Microsoft SQL Server data types and conversions.
|
||||
|
||||
## Components
|
||||
|
||||
### Type Mapping
|
||||
|
||||
Provides bidirectional conversion between canonical types and MSSQL types:
|
||||
|
||||
- **CanonicalToMSSQL**: Convert abstract types to MSSQL-specific types
|
||||
- **MSSQLToCanonical**: Convert MSSQL types to abstract representation
|
||||
|
||||
## Type Conversion Tables
|
||||
|
||||
### Canonical → MSSQL
|
||||
|
||||
| Canonical | MSSQL | Notes |
|
||||
|-----------|-------|-------|
|
||||
| int | INT | 32-bit signed integer |
|
||||
| int64 | BIGINT | 64-bit signed integer |
|
||||
| int32 | INT | 32-bit signed integer |
|
||||
| int16 | SMALLINT | 16-bit signed integer |
|
||||
| int8 | TINYINT | 8-bit unsigned integer |
|
||||
| bool | BIT | 0 (false) or 1 (true) |
|
||||
| float32 | REAL | Single precision floating point |
|
||||
| float64 | FLOAT | Double precision floating point |
|
||||
| decimal | NUMERIC | Fixed-point decimal number |
|
||||
| string | NVARCHAR(255) | Unicode variable-length string |
|
||||
| text | NVARCHAR(MAX) | Unicode large text |
|
||||
| timestamp | DATETIME2 | Date and time without timezone |
|
||||
| timestamptz | DATETIMEOFFSET | Date and time with timezone offset |
|
||||
| uuid | UNIQUEIDENTIFIER | GUID/UUID type |
|
||||
| bytea | VARBINARY(MAX) | Variable-length binary data |
|
||||
| date | DATE | Date only |
|
||||
| time | TIME | Time only |
|
||||
| json | NVARCHAR(MAX) | Stored as text (MSSQL v2016+) |
|
||||
| jsonb | NVARCHAR(MAX) | Stored as text (MSSQL v2016+) |
|
||||
|
||||
### MSSQL → Canonical
|
||||
|
||||
| MSSQL | Canonical | Notes |
|
||||
|-------|-----------|-------|
|
||||
| INT, INTEGER | int | Standard integer |
|
||||
| BIGINT | int64 | Large integer |
|
||||
| SMALLINT | int16 | Small integer |
|
||||
| TINYINT | int8 | Tiny integer |
|
||||
| BIT | bool | Boolean/bit flag |
|
||||
| REAL | float32 | Single precision |
|
||||
| FLOAT | float64 | Double precision |
|
||||
| NUMERIC, DECIMAL | decimal | Exact decimal |
|
||||
| NVARCHAR, VARCHAR | string | Variable-length string |
|
||||
| NCHAR, CHAR | string | Fixed-length string |
|
||||
| DATETIME2 | timestamp | Default timestamp |
|
||||
| DATETIMEOFFSET | timestamptz | Timestamp with timezone |
|
||||
| DATE | date | Date only |
|
||||
| TIME | time | Time only |
|
||||
| UNIQUEIDENTIFIER | uuid | UUID/GUID |
|
||||
| VARBINARY, BINARY | bytea | Binary data |
|
||||
| XML | string | Stored as text |
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/mssql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Convert canonical to MSSQL
|
||||
mssqlType := mssql.ConvertCanonicalToMSSQL("int")
|
||||
fmt.Println(mssqlType) // Output: INT
|
||||
|
||||
// Convert MSSQL to canonical
|
||||
canonicalType := mssql.ConvertMSSQLToCanonical("BIGINT")
|
||||
fmt.Println(canonicalType) // Output: int64
|
||||
|
||||
// Handle parameterized types
|
||||
canonicalType = mssql.ConvertMSSQLToCanonical("NVARCHAR(255)")
|
||||
fmt.Println(canonicalType) // Output: string
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run tests with:
|
||||
```bash
|
||||
go test ./pkg/mssql/...
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Type conversions are case-insensitive
|
||||
- Parameterized types (e.g., `NVARCHAR(255)`) have their base type extracted
|
||||
- Unmapped types default to `string` for safety
|
||||
- The package supports SQL Server 2016 and later versions
|
||||
114
pkg/mssql/datatypes.go
Normal file
114
pkg/mssql/datatypes.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package mssql
|
||||
|
||||
import "strings"
|
||||
|
||||
// CanonicalToMSSQLTypes maps canonical types to MSSQL types
|
||||
var CanonicalToMSSQLTypes = map[string]string{
|
||||
"bool": "BIT",
|
||||
"int8": "TINYINT",
|
||||
"int16": "SMALLINT",
|
||||
"int": "INT",
|
||||
"int32": "INT",
|
||||
"int64": "BIGINT",
|
||||
"uint": "BIGINT",
|
||||
"uint8": "SMALLINT",
|
||||
"uint16": "INT",
|
||||
"uint32": "BIGINT",
|
||||
"uint64": "BIGINT",
|
||||
"float32": "REAL",
|
||||
"float64": "FLOAT",
|
||||
"decimal": "NUMERIC",
|
||||
"string": "NVARCHAR(255)",
|
||||
"text": "NVARCHAR(MAX)",
|
||||
"date": "DATE",
|
||||
"time": "TIME",
|
||||
"timestamp": "DATETIME2",
|
||||
"timestamptz": "DATETIMEOFFSET",
|
||||
"uuid": "UNIQUEIDENTIFIER",
|
||||
"json": "NVARCHAR(MAX)",
|
||||
"jsonb": "NVARCHAR(MAX)",
|
||||
"bytea": "VARBINARY(MAX)",
|
||||
}
|
||||
|
||||
// MSSQLToCanonicalTypes maps MSSQL types to canonical types
|
||||
var MSSQLToCanonicalTypes = map[string]string{
|
||||
"bit": "bool",
|
||||
"tinyint": "int8",
|
||||
"smallint": "int16",
|
||||
"int": "int",
|
||||
"integer": "int",
|
||||
"bigint": "int64",
|
||||
"real": "float32",
|
||||
"float": "float64",
|
||||
"numeric": "decimal",
|
||||
"decimal": "decimal",
|
||||
"money": "decimal",
|
||||
"smallmoney": "decimal",
|
||||
"nvarchar": "string",
|
||||
"nchar": "string",
|
||||
"varchar": "string",
|
||||
"char": "string",
|
||||
"text": "string",
|
||||
"ntext": "string",
|
||||
"date": "date",
|
||||
"time": "time",
|
||||
"datetime": "timestamp",
|
||||
"datetime2": "timestamp",
|
||||
"smalldatetime": "timestamp",
|
||||
"datetimeoffset": "timestamptz",
|
||||
"uniqueidentifier": "uuid",
|
||||
"varbinary": "bytea",
|
||||
"binary": "bytea",
|
||||
"image": "bytea",
|
||||
"xml": "string",
|
||||
"json": "json",
|
||||
"sql_variant": "string",
|
||||
"hierarchyid": "string",
|
||||
"geography": "string",
|
||||
"geometry": "string",
|
||||
}
|
||||
|
||||
// ConvertCanonicalToMSSQL converts a canonical type to MSSQL type
|
||||
func ConvertCanonicalToMSSQL(canonicalType string) string {
|
||||
// Check direct mapping
|
||||
if mssqlType, exists := CanonicalToMSSQLTypes[strings.ToLower(canonicalType)]; exists {
|
||||
return mssqlType
|
||||
}
|
||||
|
||||
// Try to find by prefix
|
||||
lowerType := strings.ToLower(canonicalType)
|
||||
for canonical, mssql := range CanonicalToMSSQLTypes {
|
||||
if strings.HasPrefix(lowerType, canonical) {
|
||||
return mssql
|
||||
}
|
||||
}
|
||||
|
||||
// Default to NVARCHAR
|
||||
return "NVARCHAR(255)"
|
||||
}
|
||||
|
||||
// ConvertMSSQLToCanonical converts an MSSQL type to canonical type
|
||||
func ConvertMSSQLToCanonical(mssqlType string) string {
|
||||
// Extract base type (remove parentheses and parameters)
|
||||
baseType := mssqlType
|
||||
if idx := strings.Index(baseType, "("); idx != -1 {
|
||||
baseType = baseType[:idx]
|
||||
}
|
||||
baseType = strings.TrimSpace(baseType)
|
||||
|
||||
// Check direct mapping
|
||||
if canonicalType, exists := MSSQLToCanonicalTypes[strings.ToLower(baseType)]; exists {
|
||||
return canonicalType
|
||||
}
|
||||
|
||||
// Try to find by prefix
|
||||
lowerType := strings.ToLower(baseType)
|
||||
for mssql, canonical := range MSSQLToCanonicalTypes {
|
||||
if strings.HasPrefix(lowerType, mssql) {
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
|
||||
// Default to string
|
||||
return "string"
|
||||
}
|
||||
339
pkg/pgsql/datatypes_test.go
Normal file
339
pkg/pgsql/datatypes_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqltype string
|
||||
want bool
|
||||
}{
|
||||
// PostgreSQL types
|
||||
{"Valid PGSQL bigint", "bigint", true},
|
||||
{"Valid PGSQL integer", "integer", true},
|
||||
{"Valid PGSQL text", "text", true},
|
||||
{"Valid PGSQL boolean", "boolean", true},
|
||||
{"Valid PGSQL double precision", "double precision", true},
|
||||
{"Valid PGSQL bytea", "bytea", true},
|
||||
{"Valid PGSQL uuid", "uuid", true},
|
||||
{"Valid PGSQL jsonb", "jsonb", true},
|
||||
{"Valid PGSQL json", "json", true},
|
||||
{"Valid PGSQL timestamp", "timestamp", true},
|
||||
{"Valid PGSQL date", "date", true},
|
||||
{"Valid PGSQL time", "time", true},
|
||||
{"Valid PGSQL citext", "citext", true},
|
||||
|
||||
// Standard types
|
||||
{"Valid std double", "double", true},
|
||||
{"Valid std blob", "blob", true},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BIGINT", "BIGINT", true},
|
||||
{"Case insensitive TeXt", "TeXt", true},
|
||||
{"Case insensitive BoOlEaN", "BoOlEaN", true},
|
||||
|
||||
// Invalid types
|
||||
{"Invalid type", "invalidtype", false},
|
||||
{"Invalid type varchar", "varchar", false},
|
||||
{"Empty string", "", false},
|
||||
{"Random string", "foobar", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ValidSQLType(tt.sqltype)
|
||||
if got != tt.want {
|
||||
t.Errorf("ValidSQLType(%q) = %v, want %v", tt.sqltype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anytype string
|
||||
want string
|
||||
}{
|
||||
// Go types to PostgreSQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double precision", "float64", "double precision"},
|
||||
{"Go float32 to real", "float32", "real"},
|
||||
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||
|
||||
// SQL types remain SQL types
|
||||
{"SQL bigint", "bigint", "bigint"},
|
||||
{"SQL integer", "integer", "integer"},
|
||||
{"SQL text", "text", "text"},
|
||||
{"SQL boolean", "boolean", "boolean"},
|
||||
{"SQL uuid", "uuid", "uuid"},
|
||||
{"SQL jsonb", "jsonb", "jsonb"},
|
||||
|
||||
// Case insensitive Go types
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
{"Case insensitive STRING", "STRING", "text"},
|
||||
|
||||
// Case insensitive SQL types
|
||||
{"Case insensitive BIGINT", "BIGINT", "bigint"},
|
||||
{"Case insensitive TEXT", "TEXT", "text"},
|
||||
|
||||
// Custom types
|
||||
{"Custom sqluuid", "sqluuid", "uuid"},
|
||||
{"Custom sqljsonb", "sqljsonb", "jsonb"},
|
||||
{"Custom sqlint64", "sqlint64", "bigint"},
|
||||
|
||||
// Unknown types default to text
|
||||
{"Unknown type varchar", "varchar", "text"},
|
||||
{"Unknown type foobar", "foobar", "text"},
|
||||
{"Empty string", "", "text"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetSQLType(tt.anytype)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anytype string
|
||||
want string
|
||||
}{
|
||||
// Go types to PostgreSQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double precision", "float64", "double precision"},
|
||||
{"Go float32 to real", "float32", "real"},
|
||||
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||
|
||||
// SQL types remain SQL types
|
||||
{"SQL bigint", "bigint", "bigint"},
|
||||
{"SQL integer", "integer", "integer"},
|
||||
{"SQL text", "text", "text"},
|
||||
{"SQL boolean", "boolean", "boolean"},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
|
||||
// Unknown types remain unchanged (difference from GetSQLType)
|
||||
{"Unknown type varchar", "varchar", "varchar"},
|
||||
{"Unknown type foobar", "foobar", "foobar"},
|
||||
{"Empty string", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ConvertSQLType(tt.anytype)
|
||||
if got != tt.want {
|
||||
t.Errorf("ConvertSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGoType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
typeName string
|
||||
want bool
|
||||
}{
|
||||
// Go basic types
|
||||
{"Go bool", "bool", true},
|
||||
{"Go int64", "int64", true},
|
||||
{"Go int", "int", true},
|
||||
{"Go int32", "int32", true},
|
||||
{"Go int16", "int16", true},
|
||||
{"Go int8", "int8", true},
|
||||
{"Go uint", "uint", true},
|
||||
{"Go uint64", "uint64", true},
|
||||
{"Go uint32", "uint32", true},
|
||||
{"Go uint16", "uint16", true},
|
||||
{"Go uint8", "uint8", true},
|
||||
{"Go float64", "float64", true},
|
||||
{"Go float32", "float32", true},
|
||||
{"Go string", "string", true},
|
||||
{"Go []byte", "[]byte", true},
|
||||
|
||||
// Go custom types
|
||||
{"Go complex64", "complex64", true},
|
||||
{"Go complex128", "complex128", true},
|
||||
{"Go uintptr", "uintptr", true},
|
||||
{"Go Pointer", "Pointer", true},
|
||||
|
||||
// Custom SQL types
|
||||
{"Custom sqluuid", "sqluuid", true},
|
||||
{"Custom sqljsonb", "sqljsonb", true},
|
||||
{"Custom sqlint64", "sqlint64", true},
|
||||
{"Custom customdate", "customdate", true},
|
||||
{"Custom customtime", "customtime", true},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", true},
|
||||
{"Case insensitive InT64", "InT64", true},
|
||||
{"Case insensitive STRING", "STRING", true},
|
||||
|
||||
// SQL types (not Go types)
|
||||
{"SQL bigint", "bigint", false},
|
||||
{"SQL integer", "integer", false},
|
||||
{"SQL text", "text", false},
|
||||
{"SQL boolean", "boolean", false},
|
||||
|
||||
// Invalid types
|
||||
{"Invalid type", "invalidtype", false},
|
||||
{"Empty string", "", false},
|
||||
{"Random string", "foobar", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsGoType(tt.typeName)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsGoType(%q) = %v, want %v", tt.typeName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStdTypeFromGo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
typeName string
|
||||
want string
|
||||
}{
|
||||
// Go types to standard SQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double", "float64", "double"},
|
||||
{"Go float32 to double", "float32", "double"},
|
||||
{"Go []byte to blob", "[]byte", "blob"},
|
||||
{"Go int32 to integer", "int32", "integer"},
|
||||
{"Go int16 to smallint", "int16", "smallint"},
|
||||
|
||||
// Custom types
|
||||
{"Custom sqluuid to uuid", "sqluuid", "uuid"},
|
||||
{"Custom sqljsonb to jsonb", "sqljsonb", "jsonb"},
|
||||
{"Custom sqlint64 to bigint", "sqlint64", "bigint"},
|
||||
{"Custom customdate to date", "customdate", "date"},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
{"Case insensitive STRING", "STRING", "text"},
|
||||
|
||||
// Non-Go types remain unchanged
|
||||
{"SQL bigint unchanged", "bigint", "bigint"},
|
||||
{"SQL integer unchanged", "integer", "integer"},
|
||||
{"Invalid type unchanged", "invalidtype", "invalidtype"},
|
||||
{"Empty string unchanged", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetStdTypeFromGo(tt.typeName)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetStdTypeFromGo(%q) = %q, want %q", tt.typeName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoToStdTypesMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"bool": "boolean",
|
||||
"int64": "bigint",
|
||||
"int": "integer",
|
||||
"string": "text",
|
||||
"float64": "double",
|
||||
"[]byte": "blob",
|
||||
}
|
||||
|
||||
for goType, expectedStd := range expectedMappings {
|
||||
if stdType, ok := GoToStdTypes[goType]; !ok {
|
||||
t.Errorf("GoToStdTypes missing entry for %q", goType)
|
||||
} else if stdType != expectedStd {
|
||||
t.Errorf("GoToStdTypes[%q] = %q, want %q", goType, stdType, expectedStd)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the map is not empty
|
||||
if len(GoToStdTypes) == 0 {
|
||||
t.Error("GoToStdTypes map is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoToPGSQLTypesMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"bool": "boolean",
|
||||
"int64": "bigint",
|
||||
"int": "integer",
|
||||
"string": "text",
|
||||
"float64": "double precision",
|
||||
"float32": "real",
|
||||
"[]byte": "bytea",
|
||||
}
|
||||
|
||||
for goType, expectedPG := range expectedMappings {
|
||||
if pgType, ok := GoToPGSQLTypes[goType]; !ok {
|
||||
t.Errorf("GoToPGSQLTypes missing entry for %q", goType)
|
||||
} else if pgType != expectedPG {
|
||||
t.Errorf("GoToPGSQLTypes[%q] = %q, want %q", goType, pgType, expectedPG)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the map is not empty
|
||||
if len(GoToPGSQLTypes) == 0 {
|
||||
t.Error("GoToPGSQLTypes map is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeConversionConsistency(t *testing.T) {
|
||||
// Test that GetSQLType and ConvertSQLType are consistent for known types
|
||||
knownGoTypes := []string{"bool", "int64", "int", "string", "float64", "[]byte"}
|
||||
|
||||
for _, goType := range knownGoTypes {
|
||||
getSQLResult := GetSQLType(goType)
|
||||
convertResult := ConvertSQLType(goType)
|
||||
|
||||
if getSQLResult != convertResult {
|
||||
t.Errorf("Inconsistent results for %q: GetSQLType=%q, ConvertSQLType=%q",
|
||||
goType, getSQLResult, convertResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLTypeVsConvertSQLTypeDifference(t *testing.T) {
|
||||
// Test that GetSQLType returns "text" for unknown types
|
||||
// while ConvertSQLType returns the input unchanged
|
||||
unknownTypes := []string{"varchar", "char", "customtype", "unknowntype"}
|
||||
|
||||
for _, unknown := range unknownTypes {
|
||||
getSQLResult := GetSQLType(unknown)
|
||||
convertResult := ConvertSQLType(unknown)
|
||||
|
||||
if getSQLResult != "text" {
|
||||
t.Errorf("GetSQLType(%q) = %q, want %q", unknown, getSQLResult, "text")
|
||||
}
|
||||
|
||||
if convertResult != unknown {
|
||||
t.Errorf("ConvertSQLType(%q) = %q, want %q", unknown, convertResult, unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
36
pkg/pgsql/doc.go
Normal file
36
pkg/pgsql/doc.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Package pgsql provides PostgreSQL-specific utilities and helpers.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The pgsql package contains PostgreSQL-specific functionality including:
|
||||
// - SQL reserved keyword validation
|
||||
// - Data type mappings and conversions
|
||||
// - PostgreSQL-specific schema introspection helpers
|
||||
//
|
||||
// # Components
|
||||
//
|
||||
// keywords.go - SQL reserved keywords validation
|
||||
//
|
||||
// Provides functions to check if identifiers conflict with SQL reserved words
|
||||
// and need quoting for safe usage in PostgreSQL queries.
|
||||
//
|
||||
// datatypes.go - PostgreSQL data type utilities
|
||||
//
|
||||
// Contains mappings between PostgreSQL data types and their equivalents in other
|
||||
// systems, as well as type conversion and normalization functions.
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// // Check if identifier needs quoting
|
||||
// if pgsql.IsReservedKeyword("user") {
|
||||
// // Quote the identifier
|
||||
// }
|
||||
//
|
||||
// // Normalize data type
|
||||
// normalizedType := pgsql.NormalizeDataType("varchar(255)")
|
||||
//
|
||||
// # Purpose
|
||||
//
|
||||
// This package supports the PostgreSQL reader and writer implementations by providing
|
||||
// shared utilities for handling PostgreSQL-specific schema elements and constraints.
|
||||
package pgsql
|
||||
136
pkg/pgsql/keywords_test.go
Normal file
136
pkg/pgsql/keywords_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetPostgresKeywords(t *testing.T) {
|
||||
keywords := GetPostgresKeywords()
|
||||
|
||||
// Test that keywords are returned
|
||||
if len(keywords) == 0 {
|
||||
t.Fatal("Expected non-empty list of keywords")
|
||||
}
|
||||
|
||||
// Test that we get all keywords from the map
|
||||
expectedCount := len(postgresKeywords)
|
||||
if len(keywords) != expectedCount {
|
||||
t.Errorf("Expected %d keywords, got %d", expectedCount, len(keywords))
|
||||
}
|
||||
|
||||
// Test that all returned keywords exist in the map
|
||||
for _, keyword := range keywords {
|
||||
if !postgresKeywords[keyword] {
|
||||
t.Errorf("Keyword %q not found in postgresKeywords map", keyword)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that no duplicate keywords are returned
|
||||
seen := make(map[string]bool)
|
||||
for _, keyword := range keywords {
|
||||
if seen[keyword] {
|
||||
t.Errorf("Duplicate keyword found: %q", keyword)
|
||||
}
|
||||
seen[keyword] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyword string
|
||||
want bool
|
||||
}{
|
||||
{"SELECT keyword", "select", true},
|
||||
{"FROM keyword", "from", true},
|
||||
{"WHERE keyword", "where", true},
|
||||
{"TABLE keyword", "table", true},
|
||||
{"PRIMARY keyword", "primary", true},
|
||||
{"FOREIGN keyword", "foreign", true},
|
||||
{"CREATE keyword", "create", true},
|
||||
{"DROP keyword", "drop", true},
|
||||
{"ALTER keyword", "alter", true},
|
||||
{"INDEX keyword", "index", true},
|
||||
{"NOT keyword", "not", true},
|
||||
{"NULL keyword", "null", true},
|
||||
{"TRUE keyword", "true", true},
|
||||
{"FALSE keyword", "false", true},
|
||||
{"Non-keyword lowercase", "notakeyword", false},
|
||||
{"Non-keyword uppercase", "NOTAKEYWORD", false},
|
||||
{"Empty string", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := postgresKeywords[tt.keyword]
|
||||
if got != tt.want {
|
||||
t.Errorf("postgresKeywords[%q] = %v, want %v", tt.keyword, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMapContent(t *testing.T) {
|
||||
// Test that the map contains expected common keywords
|
||||
commonKeywords := []string{
|
||||
"select", "insert", "update", "delete", "create", "drop", "alter",
|
||||
"table", "index", "view", "schema", "function", "procedure",
|
||||
"primary", "foreign", "key", "constraint", "unique", "check",
|
||||
"null", "not", "and", "or", "like", "in", "between",
|
||||
"join", "inner", "left", "right", "cross", "full", "outer",
|
||||
"where", "having", "group", "order", "limit", "offset",
|
||||
"union", "intersect", "except",
|
||||
"begin", "commit", "rollback", "transaction",
|
||||
}
|
||||
|
||||
for _, keyword := range commonKeywords {
|
||||
if !postgresKeywords[keyword] {
|
||||
t.Errorf("Expected common keyword %q to be in postgresKeywords map", keyword)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMapSize(t *testing.T) {
|
||||
// PostgreSQL has a substantial list of reserved keywords
|
||||
// This test ensures the map has a reasonable number of entries
|
||||
minExpectedKeywords := 200 // PostgreSQL 13+ has 400+ reserved words
|
||||
|
||||
if len(postgresKeywords) < minExpectedKeywords {
|
||||
t.Errorf("Expected at least %d keywords, got %d. The map may be incomplete.",
|
||||
minExpectedKeywords, len(postgresKeywords))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPostgresKeywordsConsistency(t *testing.T) {
|
||||
// Test that calling GetPostgresKeywords multiple times returns consistent results
|
||||
keywords1 := GetPostgresKeywords()
|
||||
keywords2 := GetPostgresKeywords()
|
||||
|
||||
if len(keywords1) != len(keywords2) {
|
||||
t.Errorf("Inconsistent results: first call returned %d keywords, second call returned %d",
|
||||
len(keywords1), len(keywords2))
|
||||
}
|
||||
|
||||
// Create a map from both results to compare
|
||||
map1 := make(map[string]bool)
|
||||
map2 := make(map[string]bool)
|
||||
|
||||
for _, k := range keywords1 {
|
||||
map1[k] = true
|
||||
}
|
||||
for _, k := range keywords2 {
|
||||
map2[k] = true
|
||||
}
|
||||
|
||||
// Check that both contain the same keywords
|
||||
for k := range map1 {
|
||||
if !map2[k] {
|
||||
t.Errorf("Keyword %q present in first call but not in second", k)
|
||||
}
|
||||
}
|
||||
for k := range map2 {
|
||||
if !map1[k] {
|
||||
t.Errorf("Keyword %q present in second call but not in first", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
53
pkg/readers/doc.go
Normal file
53
pkg/readers/doc.go
Normal file
@@ -0,0 +1,53 @@
|
||||
// Package readers provides interfaces and implementations for reading database schemas
|
||||
// from various input formats and data sources.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The readers package defines a common Reader interface that all format-specific readers
|
||||
// implement. This allows RelSpec to read database schemas from multiple sources including:
|
||||
// - Live databases (PostgreSQL, SQLite)
|
||||
// - Schema definition files (DBML, DCTX, DrawDB, GraphQL)
|
||||
// - ORM model files (GORM, Bun, Drizzle, Prisma, TypeORM)
|
||||
// - Data interchange formats (JSON, YAML)
|
||||
//
|
||||
// # Architecture
|
||||
//
|
||||
// Each reader implementation is located in its own subpackage (e.g., pkg/readers/dbml,
|
||||
// pkg/readers/pgsql) and implements the Reader interface, supporting three levels of
|
||||
// granularity:
|
||||
// - ReadDatabase() - Read complete database with all schemas
|
||||
// - ReadSchema() - Read single schema with all tables
|
||||
// - ReadTable() - Read single table with all columns and metadata
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// Readers are instantiated with ReaderOptions containing source-specific configuration:
|
||||
//
|
||||
// // Read from file
|
||||
// reader := dbml.NewReader(&readers.ReaderOptions{
|
||||
// FilePath: "schema.dbml",
|
||||
// })
|
||||
// db, err := reader.ReadDatabase()
|
||||
//
|
||||
// // Read from database
|
||||
// reader := pgsql.NewReader(&readers.ReaderOptions{
|
||||
// ConnectionString: "postgres://user:pass@localhost/mydb",
|
||||
// })
|
||||
// db, err := reader.ReadDatabase()
|
||||
//
|
||||
// # Supported Formats
|
||||
//
|
||||
// - dbml: Database Markup Language files
|
||||
// - dctx: DCTX schema files
|
||||
// - drawdb: DrawDB JSON format
|
||||
// - graphql: GraphQL schema definition language
|
||||
// - json: JSON database schema
|
||||
// - yaml: YAML database schema
|
||||
// - gorm: Go GORM model structs
|
||||
// - bun: Go Bun model structs
|
||||
// - drizzle: TypeScript Drizzle ORM schemas
|
||||
// - prisma: Prisma schema language
|
||||
// - typeorm: TypeScript TypeORM entities
|
||||
// - pgsql: PostgreSQL live database introspection
|
||||
// - sqlite: SQLite database files
|
||||
package readers
|
||||
91
pkg/readers/mssql/README.md
Normal file
91
pkg/readers/mssql/README.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# MSSQL Reader
|
||||
|
||||
Reads database schema from Microsoft SQL Server databases using a live connection.
|
||||
|
||||
## Features
|
||||
|
||||
- **Live Connection**: Connects to MSSQL databases using the Microsoft ODBC driver
|
||||
- **Multi-Schema Support**: Reads multiple schemas with full support for user-defined schemas
|
||||
- **Comprehensive Metadata**: Reads tables, columns, constraints, indexes, and extended properties
|
||||
- **Type Mapping**: Converts MSSQL types to canonical types for cross-database compatibility
|
||||
- **Extended Properties**: Extracts table and column descriptions from MS_Description
|
||||
- **Identity Columns**: Maps IDENTITY columns to AutoIncrement
|
||||
- **Relationships**: Derives relationships from foreign key constraints
|
||||
|
||||
## Connection String Format
|
||||
|
||||
```
|
||||
sqlserver://[user[:password]@][host][:port][?query]
|
||||
```
|
||||
|
||||
Examples:
|
||||
```
|
||||
sqlserver://sa:password@localhost/dbname
|
||||
sqlserver://user:pass@192.168.1.100:1433/production
|
||||
sqlserver://localhost/testdb?encrypt=disable
|
||||
```
|
||||
|
||||
## Supported Constraints
|
||||
|
||||
- Primary Keys
|
||||
- Foreign Keys (with ON DELETE and ON UPDATE actions)
|
||||
- Unique Constraints
|
||||
- Check Constraints
|
||||
|
||||
## Type Mappings
|
||||
|
||||
| MSSQL Type | Canonical Type |
|
||||
|------------|----------------|
|
||||
| INT | int |
|
||||
| BIGINT | int64 |
|
||||
| SMALLINT | int16 |
|
||||
| TINYINT | int8 |
|
||||
| BIT | bool |
|
||||
| REAL | float32 |
|
||||
| FLOAT | float64 |
|
||||
| NUMERIC, DECIMAL | decimal |
|
||||
| NVARCHAR, VARCHAR | string |
|
||||
| DATETIME2 | timestamp |
|
||||
| DATETIMEOFFSET | timestamptz |
|
||||
| UNIQUEIDENTIFIER | uuid |
|
||||
| VARBINARY | bytea |
|
||||
| DATE | date |
|
||||
| TIME | time |
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/readers/mssql"
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
|
||||
reader := mssql.NewReader(&readers.ReaderOptions{
|
||||
ConnectionString: "sqlserver://sa:password@localhost/mydb",
|
||||
})
|
||||
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Process schema...
|
||||
for _, schema := range db.Schemas {
|
||||
fmt.Printf("Schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
fmt.Printf(" Table: %s\n", table.Name)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run tests with:
|
||||
```bash
|
||||
go test ./pkg/readers/mssql/...
|
||||
```
|
||||
|
||||
For integration testing with a live MSSQL database:
|
||||
```bash
|
||||
docker-compose up -d mssql
|
||||
go test -tags=integration ./pkg/readers/mssql/...
|
||||
docker-compose down
|
||||
```
|
||||
416
pkg/readers/mssql/queries.go
Normal file
416
pkg/readers/mssql/queries.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// querySchemas retrieves all user-defined schemas from the database
|
||||
func (r *Reader) querySchemas() ([]*models.Schema, error) {
|
||||
query := `
|
||||
SELECT s.name, ISNULL(ep.value, '') as description
|
||||
FROM sys.schemas s
|
||||
LEFT JOIN sys.extended_properties ep
|
||||
ON ep.major_id = s.schema_id
|
||||
AND ep.minor_id = 0
|
||||
AND ep.class = 3
|
||||
AND ep.name = 'MS_Description'
|
||||
WHERE s.name NOT IN ('dbo', 'guest', 'INFORMATION_SCHEMA', 'sys')
|
||||
ORDER BY s.name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
schemas := make([]*models.Schema, 0)
|
||||
for rows.Next() {
|
||||
var name, description string
|
||||
|
||||
if err := rows.Scan(&name, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schema := models.InitSchema(name)
|
||||
if description != "" {
|
||||
schema.Description = description
|
||||
}
|
||||
|
||||
schemas = append(schemas, schema)
|
||||
}
|
||||
|
||||
// Always include dbo schema if it has tables
|
||||
dboSchema := models.InitSchema("dbo")
|
||||
schemas = append(schemas, dboSchema)
|
||||
|
||||
return schemas, rows.Err()
|
||||
}
|
||||
|
||||
// queryTables retrieves all tables for a given schema
|
||||
func (r *Reader) queryTables(schemaName string) ([]*models.Table, error) {
|
||||
query := `
|
||||
SELECT t.table_schema, t.table_name, ISNULL(ep.value, '') as description
|
||||
FROM information_schema.tables t
|
||||
LEFT JOIN sys.extended_properties ep
|
||||
ON ep.major_id = OBJECT_ID(QUOTENAME(t.table_schema) + '.' + QUOTENAME(t.table_name))
|
||||
AND ep.minor_id = 0
|
||||
AND ep.class = 1
|
||||
AND ep.name = 'MS_Description'
|
||||
WHERE t.table_schema = ? AND t.table_type = 'BASE TABLE'
|
||||
ORDER BY t.table_name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tables := make([]*models.Table, 0)
|
||||
for rows.Next() {
|
||||
var schema, tableName, description string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
table := models.InitTable(tableName, schema)
|
||||
if description != "" {
|
||||
table.Description = description
|
||||
}
|
||||
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// queryColumns retrieves all columns for tables in a schema
|
||||
// Returns map[schema.table]map[columnName]*Column
|
||||
func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.Column, error) {
|
||||
query := `
|
||||
SELECT
|
||||
c.table_schema,
|
||||
c.table_name,
|
||||
c.column_name,
|
||||
c.ordinal_position,
|
||||
c.column_default,
|
||||
c.is_nullable,
|
||||
c.data_type,
|
||||
c.character_maximum_length,
|
||||
c.numeric_precision,
|
||||
c.numeric_scale,
|
||||
ISNULL(ep.value, '') as description,
|
||||
COLUMNPROPERTY(OBJECT_ID(QUOTENAME(c.table_schema) + '.' + QUOTENAME(c.table_name)), c.column_name, 'IsIdentity') as is_identity
|
||||
FROM information_schema.columns c
|
||||
LEFT JOIN sys.extended_properties ep
|
||||
ON ep.major_id = OBJECT_ID(QUOTENAME(c.table_schema) + '.' + QUOTENAME(c.table_name))
|
||||
AND ep.minor_id = COLUMNPROPERTY(OBJECT_ID(QUOTENAME(c.table_schema) + '.' + QUOTENAME(c.table_name)), c.column_name, 'ColumnId')
|
||||
AND ep.class = 1
|
||||
AND ep.name = 'MS_Description'
|
||||
WHERE c.table_schema = ?
|
||||
ORDER BY c.table_schema, c.table_name, c.ordinal_position
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columnsMap := make(map[string]map[string]*models.Column)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, columnName, isNullable, dataType, description string
|
||||
var ordinalPosition int
|
||||
var columnDefault, charMaxLength, numPrecision, numScale, isIdentity *int
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &description, &isIdentity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
column := models.InitColumn(columnName, tableName, schema)
|
||||
column.Type = r.mapDataType(dataType)
|
||||
column.NotNull = (isNullable == "NO")
|
||||
column.Sequence = uint(ordinalPosition)
|
||||
|
||||
if description != "" {
|
||||
column.Description = description
|
||||
}
|
||||
|
||||
// Check if this is an identity column (auto-increment)
|
||||
if isIdentity != nil && *isIdentity == 1 {
|
||||
column.AutoIncrement = true
|
||||
}
|
||||
|
||||
if charMaxLength != nil && *charMaxLength > 0 {
|
||||
column.Length = *charMaxLength
|
||||
}
|
||||
|
||||
if numPrecision != nil && *numPrecision > 0 {
|
||||
column.Precision = *numPrecision
|
||||
}
|
||||
|
||||
if numScale != nil && *numScale > 0 {
|
||||
column.Scale = *numScale
|
||||
}
|
||||
|
||||
// Create table key
|
||||
tableKey := schema + "." + tableName
|
||||
if columnsMap[tableKey] == nil {
|
||||
columnsMap[tableKey] = make(map[string]*models.Column)
|
||||
}
|
||||
columnsMap[tableKey][columnName] = column
|
||||
}
|
||||
|
||||
return columnsMap, rows.Err()
|
||||
}
|
||||
|
||||
// queryPrimaryKeys retrieves all primary key constraints for a schema
|
||||
// Returns map[schema.table]*Constraint
|
||||
func (r *Reader) queryPrimaryKeys(schemaName string) (map[string]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
i.name as constraint_name,
|
||||
STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) as columns
|
||||
FROM sys.tables t
|
||||
INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.is_primary_key = 1
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
|
||||
INNER JOIN sys.columns c ON t.object_id = c.object_id AND ic.column_id = c.column_id
|
||||
WHERE s.name = ?
|
||||
GROUP BY s.name, t.name, i.name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
primaryKeys := make(map[string]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, columnsStr string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &columnsStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := strings.Split(columnsStr, ",")
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.PrimaryKeyConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = columns
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
primaryKeys[tableKey] = constraint
|
||||
}
|
||||
|
||||
return primaryKeys, rows.Err()
|
||||
}
|
||||
|
||||
// queryForeignKeys retrieves all foreign key constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryForeignKeys(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
fk.name as constraint_name,
|
||||
rs.name as referenced_schema,
|
||||
rt.name as referenced_table,
|
||||
STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY fkc.constraint_column_id) as columns,
|
||||
STRING_AGG(rc.name, ',') WITHIN GROUP (ORDER BY fkc.constraint_column_id) as referenced_columns,
|
||||
fk.delete_referential_action_desc,
|
||||
fk.update_referential_action_desc
|
||||
FROM sys.foreign_keys fk
|
||||
INNER JOIN sys.tables t ON fk.parent_object_id = t.object_id
|
||||
INNER JOIN sys.tables rt ON fk.referenced_object_id = rt.object_id
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
INNER JOIN sys.schemas rs ON rt.schema_id = rs.schema_id
|
||||
INNER JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id
|
||||
INNER JOIN sys.columns c ON fkc.parent_object_id = c.object_id AND fkc.parent_column_id = c.column_id
|
||||
INNER JOIN sys.columns rc ON fkc.referenced_object_id = rc.object_id AND fkc.referenced_column_id = rc.column_id
|
||||
WHERE s.name = ?
|
||||
GROUP BY s.name, t.name, fk.name, rs.name, rt.name, fk.delete_referential_action_desc, fk.update_referential_action_desc
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
foreignKeys := make(map[string][]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, refSchema, refTable, columnsStr, refColumnsStr, deleteAction, updateAction string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &refSchema, &refTable, &columnsStr, &refColumnsStr, &deleteAction, &updateAction); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := strings.Split(columnsStr, ",")
|
||||
refColumns := strings.Split(refColumnsStr, ",")
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = columns
|
||||
constraint.ReferencedSchema = refSchema
|
||||
constraint.ReferencedTable = refTable
|
||||
constraint.ReferencedColumns = refColumns
|
||||
constraint.OnDelete = strings.ToUpper(deleteAction)
|
||||
constraint.OnUpdate = strings.ToUpper(updateAction)
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
foreignKeys[tableKey] = append(foreignKeys[tableKey], constraint)
|
||||
}
|
||||
|
||||
return foreignKeys, rows.Err()
|
||||
}
|
||||
|
||||
// queryUniqueConstraints retrieves all unique constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryUniqueConstraints(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
i.name as constraint_name,
|
||||
STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) as columns
|
||||
FROM sys.tables t
|
||||
INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.is_unique = 1 AND i.is_primary_key = 0
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
|
||||
INNER JOIN sys.columns c ON t.object_id = c.object_id AND ic.column_id = c.column_id
|
||||
WHERE s.name = ?
|
||||
GROUP BY s.name, t.name, i.name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
uniqueConstraints := make(map[string][]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, columnsStr string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &columnsStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := strings.Split(columnsStr, ",")
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.UniqueConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = columns
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
uniqueConstraints[tableKey] = append(uniqueConstraints[tableKey], constraint)
|
||||
}
|
||||
|
||||
return uniqueConstraints, rows.Err()
|
||||
}
|
||||
|
||||
// queryCheckConstraints retrieves all check constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryCheckConstraints(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
cc.name as constraint_name,
|
||||
cc.definition
|
||||
FROM sys.tables t
|
||||
INNER JOIN sys.check_constraints cc ON t.object_id = cc.parent_object_id
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
WHERE s.name = ?
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
checkConstraints := make(map[string][]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, definition string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &definition); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.CheckConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Expression = definition
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
checkConstraints[tableKey] = append(checkConstraints[tableKey], constraint)
|
||||
}
|
||||
|
||||
return checkConstraints, rows.Err()
|
||||
}
|
||||
|
||||
// queryIndexes retrieves all indexes for a schema
|
||||
// Returns map[schema.table][]*Index
|
||||
func (r *Reader) queryIndexes(schemaName string) (map[string][]*models.Index, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
i.name as index_name,
|
||||
i.is_unique,
|
||||
STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) as columns
|
||||
FROM sys.tables t
|
||||
INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.is_primary_key = 0 AND i.name IS NOT NULL
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
|
||||
INNER JOIN sys.columns c ON t.object_id = c.object_id AND ic.column_id = c.column_id
|
||||
WHERE s.name = ?
|
||||
GROUP BY s.name, t.name, i.name, i.is_unique
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
indexes := make(map[string][]*models.Index)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, indexName, columnsStr string
|
||||
var isUnique int
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &indexName, &isUnique, &columnsStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := strings.Split(columnsStr, ",")
|
||||
|
||||
index := models.InitIndex(indexName, tableName, schema)
|
||||
index.Columns = columns
|
||||
index.Unique = (isUnique == 1)
|
||||
index.Type = "btree" // MSSQL uses btree by default
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
indexes[tableKey] = append(indexes[tableKey], index)
|
||||
}
|
||||
|
||||
return indexes, rows.Err()
|
||||
}
|
||||
266
pkg/readers/mssql/reader.go
Normal file
266
pkg/readers/mssql/reader.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/microsoft/go-mssqldb" // MSSQL driver
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/mssql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
// Reader implements the readers.Reader interface for MSSQL databases
|
||||
type Reader struct {
|
||||
options *readers.ReaderOptions
|
||||
db *sql.DB
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewReader creates a new MSSQL reader
|
||||
func NewReader(options *readers.ReaderOptions) *Reader {
|
||||
return &Reader{
|
||||
options: options,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadDatabase reads the entire database schema from MSSQL
|
||||
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
||||
// Validate connection string
|
||||
if r.options.ConnectionString == "" {
|
||||
return nil, fmt.Errorf("connection string is required")
|
||||
}
|
||||
|
||||
// Connect to the database
|
||||
if err := r.connect(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
defer r.close()
|
||||
|
||||
// Get database name
|
||||
var dbName string
|
||||
err := r.db.QueryRowContext(r.ctx, "SELECT DB_NAME()").Scan(&dbName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get database name: %w", err)
|
||||
}
|
||||
|
||||
// Initialize database model
|
||||
db := models.InitDatabase(dbName)
|
||||
db.DatabaseType = models.MSSQLDatabaseType
|
||||
db.SourceFormat = "mssql"
|
||||
|
||||
// Get MSSQL version
|
||||
var version string
|
||||
err = r.db.QueryRowContext(r.ctx, "SELECT @@VERSION").Scan(&version)
|
||||
if err == nil {
|
||||
db.DatabaseVersion = version
|
||||
}
|
||||
|
||||
// Query all schemas
|
||||
schemas, err := r.querySchemas()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query schemas: %w", err)
|
||||
}
|
||||
|
||||
// Process each schema
|
||||
for _, schema := range schemas {
|
||||
// Query tables for this schema
|
||||
tables, err := r.queryTables(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
schema.Tables = tables
|
||||
|
||||
// Query columns for tables
|
||||
columnsMap, err := r.queryColumns(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Populate table columns
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if cols, exists := columnsMap[tableKey]; exists {
|
||||
table.Columns = cols
|
||||
}
|
||||
}
|
||||
|
||||
// Query primary keys
|
||||
primaryKeys, err := r.queryPrimaryKeys(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query primary keys for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply primary keys to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if pk, exists := primaryKeys[tableKey]; exists {
|
||||
table.Constraints[pk.Name] = pk
|
||||
// Mark columns as primary key and not null
|
||||
for _, colName := range pk.Columns {
|
||||
if col, colExists := table.Columns[colName]; colExists {
|
||||
col.IsPrimaryKey = true
|
||||
col.NotNull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query foreign keys
|
||||
foreignKeys, err := r.queryForeignKeys(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query foreign keys for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply foreign keys to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if fks, exists := foreignKeys[tableKey]; exists {
|
||||
for _, fk := range fks {
|
||||
table.Constraints[fk.Name] = fk
|
||||
// Derive relationship from foreign key
|
||||
r.deriveRelationship(table, fk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query unique constraints
|
||||
uniqueConstraints, err := r.queryUniqueConstraints(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query unique constraints for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply unique constraints to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if ucs, exists := uniqueConstraints[tableKey]; exists {
|
||||
for _, uc := range ucs {
|
||||
table.Constraints[uc.Name] = uc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query check constraints
|
||||
checkConstraints, err := r.queryCheckConstraints(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query check constraints for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply check constraints to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if ccs, exists := checkConstraints[tableKey]; exists {
|
||||
for _, cc := range ccs {
|
||||
table.Constraints[cc.Name] = cc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query indexes
|
||||
indexes, err := r.queryIndexes(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query indexes for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply indexes to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if idxs, exists := indexes[tableKey]; exists {
|
||||
for _, idx := range idxs {
|
||||
table.Indexes[idx.Name] = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set RefDatabase for schema
|
||||
schema.RefDatabase = db
|
||||
|
||||
// Set RefSchema for tables
|
||||
for _, table := range schema.Tables {
|
||||
table.RefSchema = schema
|
||||
}
|
||||
|
||||
// Add schema to database
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// ReadSchema reads a single schema (returns the first schema from the database)
|
||||
func (r *Reader) ReadSchema() (*models.Schema, error) {
|
||||
db, err := r.ReadDatabase()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(db.Schemas) == 0 {
|
||||
return nil, fmt.Errorf("no schemas found in database")
|
||||
}
|
||||
return db.Schemas[0], nil
|
||||
}
|
||||
|
||||
// ReadTable reads a single table (returns the first table from the first schema)
|
||||
func (r *Reader) ReadTable() (*models.Table, error) {
|
||||
schema, err := r.ReadSchema()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(schema.Tables) == 0 {
|
||||
return nil, fmt.Errorf("no tables found in schema")
|
||||
}
|
||||
return schema.Tables[0], nil
|
||||
}
|
||||
|
||||
// connect establishes a connection to the MSSQL database
|
||||
func (r *Reader) connect() error {
|
||||
db, err := sql.Open("mssql", r.options.ConnectionString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Test connection
|
||||
if err = db.PingContext(r.ctx); err != nil {
|
||||
db.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
r.db = db
|
||||
return nil
|
||||
}
|
||||
|
||||
// close closes the database connection
|
||||
func (r *Reader) close() {
|
||||
if r.db != nil {
|
||||
r.db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// mapDataType maps MSSQL data types to canonical types
|
||||
func (r *Reader) mapDataType(mssqlType string) string {
|
||||
return mssql.ConvertMSSQLToCanonical(mssqlType)
|
||||
}
|
||||
|
||||
// deriveRelationship creates a relationship from a foreign key constraint
|
||||
func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint) {
|
||||
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
||||
|
||||
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
||||
relationship.FromTable = table.Name
|
||||
relationship.FromSchema = table.Schema
|
||||
relationship.ToTable = fk.ReferencedTable
|
||||
relationship.ToSchema = fk.ReferencedSchema
|
||||
relationship.ForeignKey = fk.Name
|
||||
|
||||
// Store constraint actions in properties
|
||||
if fk.OnDelete != "" {
|
||||
relationship.Properties["on_delete"] = fk.OnDelete
|
||||
}
|
||||
if fk.OnUpdate != "" {
|
||||
relationship.Properties["on_update"] = fk.OnUpdate
|
||||
}
|
||||
|
||||
table.Relationships[relationshipName] = relationship
|
||||
}
|
||||
86
pkg/readers/mssql/reader_test.go
Normal file
86
pkg/readers/mssql/reader_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/mssql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestMapDataType tests MSSQL type mapping to canonical types
|
||||
func TestMapDataType(t *testing.T) {
|
||||
reader := NewReader(&readers.ReaderOptions{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mssqlType string
|
||||
expectedType string
|
||||
}{
|
||||
{"INT to int", "INT", "int"},
|
||||
{"BIGINT to int64", "BIGINT", "int64"},
|
||||
{"BIT to bool", "BIT", "bool"},
|
||||
{"NVARCHAR to string", "NVARCHAR(255)", "string"},
|
||||
{"DATETIME2 to timestamp", "DATETIME2", "timestamp"},
|
||||
{"DATETIMEOFFSET to timestamptz", "DATETIMEOFFSET", "timestamptz"},
|
||||
{"UNIQUEIDENTIFIER to uuid", "UNIQUEIDENTIFIER", "uuid"},
|
||||
{"FLOAT to float64", "FLOAT", "float64"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.mssqlType)
|
||||
assert.Equal(t, tt.expectedType, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertCanonicalToMSSQL tests canonical to MSSQL type conversion
|
||||
func TestConvertCanonicalToMSSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
canonicalType string
|
||||
expectedMSSQL string
|
||||
}{
|
||||
{"int to INT", "int", "INT"},
|
||||
{"int64 to BIGINT", "int64", "BIGINT"},
|
||||
{"bool to BIT", "bool", "BIT"},
|
||||
{"string to NVARCHAR(255)", "string", "NVARCHAR(255)"},
|
||||
{"text to NVARCHAR(MAX)", "text", "NVARCHAR(MAX)"},
|
||||
{"timestamp to DATETIME2", "timestamp", "DATETIME2"},
|
||||
{"timestamptz to DATETIMEOFFSET", "timestamptz", "DATETIMEOFFSET"},
|
||||
{"uuid to UNIQUEIDENTIFIER", "uuid", "UNIQUEIDENTIFIER"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := mssql.ConvertCanonicalToMSSQL(tt.canonicalType)
|
||||
assert.Equal(t, tt.expectedMSSQL, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertMSSQLToCanonical tests MSSQL to canonical type conversion
|
||||
func TestConvertMSSQLToCanonical(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mssqlType string
|
||||
expectedType string
|
||||
}{
|
||||
{"INT to int", "INT", "int"},
|
||||
{"BIGINT to int64", "BIGINT", "int64"},
|
||||
{"BIT to bool", "BIT", "bool"},
|
||||
{"NVARCHAR with params", "NVARCHAR(255)", "string"},
|
||||
{"DATETIME2 to timestamp", "DATETIME2", "timestamp"},
|
||||
{"DATETIMEOFFSET to timestamptz", "DATETIMEOFFSET", "timestamptz"},
|
||||
{"UNIQUEIDENTIFIER to uuid", "UNIQUEIDENTIFIER", "uuid"},
|
||||
{"VARBINARY to bytea", "VARBINARY(MAX)", "bytea"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := mssql.ConvertMSSQLToCanonical(tt.mssqlType)
|
||||
assert.Equal(t, tt.expectedType, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -231,14 +231,13 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
||||
}
|
||||
|
||||
column := models.InitColumn(columnName, tableName, schema)
|
||||
column.Type = r.mapDataType(dataType, udtName)
|
||||
column.NotNull = (isNullable == "NO")
|
||||
column.Sequence = uint(ordinalPosition)
|
||||
|
||||
// Check if this is a serial type (has nextval default)
|
||||
hasNextval := false
|
||||
if columnDefault != nil {
|
||||
// Parse default value - remove nextval for sequences
|
||||
defaultVal := *columnDefault
|
||||
if strings.HasPrefix(defaultVal, "nextval") {
|
||||
hasNextval = true
|
||||
column.AutoIncrement = true
|
||||
column.Default = defaultVal
|
||||
} else {
|
||||
@@ -246,6 +245,11 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
||||
}
|
||||
}
|
||||
|
||||
// Map data type, preserving serial types when detected
|
||||
column.Type = r.mapDataType(dataType, udtName, hasNextval)
|
||||
column.NotNull = (isNullable == "NO")
|
||||
column.Sequence = uint(ordinalPosition)
|
||||
|
||||
if description != nil {
|
||||
column.Description = *description
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package pgsql
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
@@ -259,33 +260,46 @@ func (r *Reader) close() {
|
||||
}
|
||||
|
||||
// mapDataType maps PostgreSQL data types to canonical types
|
||||
func (r *Reader) mapDataType(pgType, udtName string) string {
|
||||
func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
|
||||
// If the column has a nextval default, it's likely a serial type
|
||||
// Map to the appropriate serial type instead of the base integer type
|
||||
if hasNextval {
|
||||
switch strings.ToLower(pgType) {
|
||||
case "integer", "int", "int4":
|
||||
return "serial"
|
||||
case "bigint", "int8":
|
||||
return "bigserial"
|
||||
case "smallint", "int2":
|
||||
return "smallserial"
|
||||
}
|
||||
}
|
||||
|
||||
// Map common PostgreSQL types
|
||||
typeMap := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int64",
|
||||
"smallint": "int16",
|
||||
"int": "int",
|
||||
"int2": "int16",
|
||||
"int4": "int",
|
||||
"int8": "int64",
|
||||
"serial": "int",
|
||||
"bigserial": "int64",
|
||||
"smallserial": "int16",
|
||||
"numeric": "decimal",
|
||||
"integer": "integer",
|
||||
"bigint": "bigint",
|
||||
"smallint": "smallint",
|
||||
"int": "integer",
|
||||
"int2": "smallint",
|
||||
"int4": "integer",
|
||||
"int8": "bigint",
|
||||
"serial": "serial",
|
||||
"bigserial": "bigserial",
|
||||
"smallserial": "smallserial",
|
||||
"numeric": "numeric",
|
||||
"decimal": "decimal",
|
||||
"real": "float32",
|
||||
"double precision": "float64",
|
||||
"float4": "float32",
|
||||
"float8": "float64",
|
||||
"money": "decimal",
|
||||
"character varying": "string",
|
||||
"varchar": "string",
|
||||
"character": "string",
|
||||
"char": "string",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"bool": "bool",
|
||||
"real": "real",
|
||||
"double precision": "double precision",
|
||||
"float4": "real",
|
||||
"float8": "double precision",
|
||||
"money": "money",
|
||||
"character varying": "varchar",
|
||||
"varchar": "varchar",
|
||||
"character": "char",
|
||||
"char": "char",
|
||||
"text": "text",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"date": "date",
|
||||
"time": "time",
|
||||
"time without time zone": "time",
|
||||
|
||||
@@ -177,20 +177,20 @@ func TestMapDataType(t *testing.T) {
|
||||
udtName string
|
||||
expected string
|
||||
}{
|
||||
{"integer", "int4", "int"},
|
||||
{"bigint", "int8", "int64"},
|
||||
{"smallint", "int2", "int16"},
|
||||
{"character varying", "varchar", "string"},
|
||||
{"text", "text", "string"},
|
||||
{"boolean", "bool", "bool"},
|
||||
{"integer", "int4", "integer"},
|
||||
{"bigint", "int8", "bigint"},
|
||||
{"smallint", "int2", "smallint"},
|
||||
{"character varying", "varchar", "varchar"},
|
||||
{"text", "text", "text"},
|
||||
{"boolean", "bool", "boolean"},
|
||||
{"timestamp without time zone", "timestamp", "timestamp"},
|
||||
{"timestamp with time zone", "timestamptz", "timestamptz"},
|
||||
{"json", "json", "json"},
|
||||
{"jsonb", "jsonb", "jsonb"},
|
||||
{"uuid", "uuid", "uuid"},
|
||||
{"numeric", "numeric", "decimal"},
|
||||
{"real", "float4", "float32"},
|
||||
{"double precision", "float8", "float64"},
|
||||
{"numeric", "numeric", "numeric"},
|
||||
{"real", "float4", "real"},
|
||||
{"double precision", "float8", "double precision"},
|
||||
{"date", "date", "date"},
|
||||
{"time without time zone", "time", "time"},
|
||||
{"bytea", "bytea", "bytea"},
|
||||
@@ -199,12 +199,31 @@ func TestMapDataType(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pgType, func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.pgType, tt.udtName)
|
||||
result := reader.mapDataType(tt.pgType, tt.udtName, false)
|
||||
if result != tt.expected {
|
||||
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test serial type detection with hasNextval=true
|
||||
serialTests := []struct {
|
||||
pgType string
|
||||
expected string
|
||||
}{
|
||||
{"integer", "serial"},
|
||||
{"bigint", "bigserial"},
|
||||
{"smallint", "smallserial"},
|
||||
}
|
||||
|
||||
for _, tt := range serialTests {
|
||||
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.pgType, "", true)
|
||||
if result != tt.expected {
|
||||
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIndexDefinition(t *testing.T) {
|
||||
|
||||
75
pkg/readers/sqlite/README.md
Normal file
75
pkg/readers/sqlite/README.md
Normal file
@@ -0,0 +1,75 @@
|
||||
# SQLite Reader
|
||||
|
||||
Reads database schema from SQLite database files.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/sqlite"
|
||||
)
|
||||
|
||||
// Using file path
|
||||
options := &readers.ReaderOptions{
|
||||
FilePath: "path/to/database.db",
|
||||
}
|
||||
|
||||
reader := sqlite.NewReader(options)
|
||||
db, err := reader.ReadDatabase()
|
||||
|
||||
// Or using connection string
|
||||
options := &readers.ReaderOptions{
|
||||
ConnectionString: "path/to/database.db",
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- Reads tables with columns and data types
|
||||
- Reads views with definitions
|
||||
- Reads primary keys
|
||||
- Reads foreign keys with CASCADE actions
|
||||
- Reads indexes (non-auto-generated)
|
||||
- Maps SQLite types to canonical types
|
||||
- Derives relationships from foreign keys
|
||||
|
||||
## SQLite Specifics
|
||||
|
||||
- SQLite doesn't support schemas, creates single "main" schema
|
||||
- Uses pure Go driver (modernc.org/sqlite) - no CGo required
|
||||
- Supports both file path and connection string
|
||||
- Auto-increment detection for INTEGER PRIMARY KEY columns
|
||||
- Foreign keys require `PRAGMA foreign_keys = ON` to be set
|
||||
|
||||
## Example Schema
|
||||
|
||||
```sql
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username VARCHAR(50) NOT NULL UNIQUE,
|
||||
email VARCHAR(100) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE posts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(200) NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
);
|
||||
```
|
||||
|
||||
## Type Mappings
|
||||
|
||||
| SQLite Type | Canonical Type |
|
||||
|-------------|---------------|
|
||||
| INTEGER, INT | int |
|
||||
| BIGINT | int64 |
|
||||
| REAL, DOUBLE | float64 |
|
||||
| TEXT, VARCHAR | string |
|
||||
| BLOB | bytea |
|
||||
| BOOLEAN | bool |
|
||||
| DATE | date |
|
||||
| DATETIME, TIMESTAMP | timestamp |
|
||||
306
pkg/readers/sqlite/queries.go
Normal file
306
pkg/readers/sqlite/queries.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// queryTables retrieves all tables from the SQLite database
|
||||
func (r *Reader) queryTables() ([]*models.Table, error) {
|
||||
query := `
|
||||
SELECT name
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table'
|
||||
AND name NOT LIKE 'sqlite_%'
|
||||
ORDER BY name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tables := make([]*models.Table, 0)
|
||||
for rows.Next() {
|
||||
var tableName string
|
||||
|
||||
if err := rows.Scan(&tableName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
table := models.InitTable(tableName, "main")
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// queryViews retrieves all views from the SQLite database
|
||||
func (r *Reader) queryViews() ([]*models.View, error) {
|
||||
query := `
|
||||
SELECT name, sql
|
||||
FROM sqlite_master
|
||||
WHERE type = 'view'
|
||||
ORDER BY name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
views := make([]*models.View, 0)
|
||||
for rows.Next() {
|
||||
var viewName string
|
||||
var sql *string
|
||||
|
||||
if err := rows.Scan(&viewName, &sql); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
view := models.InitView(viewName, "main")
|
||||
if sql != nil {
|
||||
view.Definition = *sql
|
||||
}
|
||||
|
||||
views = append(views, view)
|
||||
}
|
||||
|
||||
return views, rows.Err()
|
||||
}
|
||||
|
||||
// queryColumns retrieves all columns for a given table or view
|
||||
func (r *Reader) queryColumns(tableName string) (map[string]*models.Column, error) {
|
||||
query := fmt.Sprintf("PRAGMA table_info(%s)", tableName)
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns := make(map[string]*models.Column)
|
||||
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name, dataType string
|
||||
var notNull, pk int
|
||||
var defaultValue *string
|
||||
|
||||
if err := rows.Scan(&cid, &name, &dataType, ¬Null, &defaultValue, &pk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
column := models.InitColumn(name, tableName, "main")
|
||||
column.Type = r.mapDataType(strings.ToUpper(dataType))
|
||||
column.NotNull = (notNull == 1)
|
||||
column.IsPrimaryKey = (pk > 0)
|
||||
column.Sequence = uint(cid + 1)
|
||||
|
||||
if defaultValue != nil {
|
||||
column.Default = *defaultValue
|
||||
}
|
||||
|
||||
// Check for autoincrement (SQLite uses INTEGER PRIMARY KEY AUTOINCREMENT)
|
||||
if pk > 0 && strings.EqualFold(dataType, "INTEGER") {
|
||||
column.AutoIncrement = r.isAutoIncrement(tableName, name)
|
||||
}
|
||||
|
||||
columns[name] = column
|
||||
}
|
||||
|
||||
return columns, rows.Err()
|
||||
}
|
||||
|
||||
// isAutoIncrement checks if a column is autoincrement
|
||||
func (r *Reader) isAutoIncrement(tableName, columnName string) bool {
|
||||
// Check sqlite_sequence table or parse CREATE TABLE statement
|
||||
query := `
|
||||
SELECT sql
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table' AND name = ?
|
||||
`
|
||||
|
||||
var sql string
|
||||
err := r.db.QueryRowContext(r.ctx, query, tableName).Scan(&sql)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the SQL contains AUTOINCREMENT for this column
|
||||
return strings.Contains(strings.ToUpper(sql), strings.ToUpper(columnName)+" INTEGER PRIMARY KEY AUTOINCREMENT") ||
|
||||
strings.Contains(strings.ToUpper(sql), strings.ToUpper(columnName)+" INTEGER AUTOINCREMENT")
|
||||
}
|
||||
|
||||
// queryPrimaryKey retrieves the primary key constraint for a table
|
||||
func (r *Reader) queryPrimaryKey(tableName string) (*models.Constraint, error) {
|
||||
query := fmt.Sprintf("PRAGMA table_info(%s)", tableName)
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var pkColumns []string
|
||||
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name, dataType string
|
||||
var notNull, pk int
|
||||
var defaultValue *string
|
||||
|
||||
if err := rows.Scan(&cid, &name, &dataType, ¬Null, &defaultValue, &pk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pk > 0 {
|
||||
pkColumns = append(pkColumns, name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pkColumns) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create primary key constraint
|
||||
constraintName := fmt.Sprintf("%s_pkey", tableName)
|
||||
constraint := models.InitConstraint(constraintName, models.PrimaryKeyConstraint)
|
||||
constraint.Schema = "main"
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = pkColumns
|
||||
|
||||
return constraint, rows.Err()
|
||||
}
|
||||
|
||||
// queryForeignKeys retrieves all foreign key constraints for a table
|
||||
func (r *Reader) queryForeignKeys(tableName string) ([]*models.Constraint, error) {
|
||||
query := fmt.Sprintf("PRAGMA foreign_key_list(%s)", tableName)
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Group foreign keys by id (since composite FKs have multiple rows)
|
||||
fkMap := make(map[int]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var id, seq int
|
||||
var referencedTable, fromColumn, toColumn string
|
||||
var onUpdate, onDelete, match string
|
||||
|
||||
if err := rows.Scan(&id, &seq, &referencedTable, &fromColumn, &toColumn, &onUpdate, &onDelete, &match); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, exists := fkMap[id]; !exists {
|
||||
constraintName := fmt.Sprintf("%s_%s_fkey", tableName, referencedTable)
|
||||
if id > 0 {
|
||||
constraintName = fmt.Sprintf("%s_%s_fkey_%d", tableName, referencedTable, id)
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint)
|
||||
constraint.Schema = "main"
|
||||
constraint.Table = tableName
|
||||
constraint.ReferencedSchema = "main"
|
||||
constraint.ReferencedTable = referencedTable
|
||||
constraint.OnUpdate = onUpdate
|
||||
constraint.OnDelete = onDelete
|
||||
constraint.Columns = []string{}
|
||||
constraint.ReferencedColumns = []string{}
|
||||
|
||||
fkMap[id] = constraint
|
||||
}
|
||||
|
||||
// Add column to the constraint
|
||||
fkMap[id].Columns = append(fkMap[id].Columns, fromColumn)
|
||||
fkMap[id].ReferencedColumns = append(fkMap[id].ReferencedColumns, toColumn)
|
||||
}
|
||||
|
||||
// Convert map to slice
|
||||
foreignKeys := make([]*models.Constraint, 0, len(fkMap))
|
||||
for _, fk := range fkMap {
|
||||
foreignKeys = append(foreignKeys, fk)
|
||||
}
|
||||
|
||||
return foreignKeys, rows.Err()
|
||||
}
|
||||
|
||||
// queryIndexes retrieves all indexes for a table
|
||||
func (r *Reader) queryIndexes(tableName string) ([]*models.Index, error) {
|
||||
query := fmt.Sprintf("PRAGMA index_list(%s)", tableName)
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
indexes := make([]*models.Index, 0)
|
||||
|
||||
for rows.Next() {
|
||||
var seq int
|
||||
var name string
|
||||
var unique int
|
||||
var origin string
|
||||
var partial int
|
||||
|
||||
if err := rows.Scan(&seq, &name, &unique, &origin, &partial); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Skip auto-generated indexes (origin = 'pk' for primary keys, etc.)
|
||||
// origin: c = CREATE INDEX, u = UNIQUE constraint, pk = PRIMARY KEY
|
||||
if origin == "pk" || origin == "u" {
|
||||
continue
|
||||
}
|
||||
|
||||
index := models.InitIndex(name, tableName, "main")
|
||||
index.Unique = (unique == 1)
|
||||
|
||||
// Get index columns
|
||||
columns, err := r.queryIndexColumns(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
index.Columns = columns
|
||||
|
||||
indexes = append(indexes, index)
|
||||
}
|
||||
|
||||
return indexes, rows.Err()
|
||||
}
|
||||
|
||||
// queryIndexColumns retrieves the columns for a specific index
|
||||
func (r *Reader) queryIndexColumns(indexName string) ([]string, error) {
|
||||
query := fmt.Sprintf("PRAGMA index_info(%s)", indexName)
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns := make([]string, 0)
|
||||
|
||||
for rows.Next() {
|
||||
var seqno, cid int
|
||||
var name *string
|
||||
|
||||
if err := rows.Scan(&seqno, &cid, &name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if name != nil {
|
||||
columns = append(columns, *name)
|
||||
}
|
||||
}
|
||||
|
||||
return columns, rows.Err()
|
||||
}
|
||||
261
pkg/readers/sqlite/reader.go
Normal file
261
pkg/readers/sqlite/reader.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
_ "modernc.org/sqlite" // SQLite driver
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
// Reader implements the readers.Reader interface for SQLite databases
|
||||
type Reader struct {
|
||||
options *readers.ReaderOptions
|
||||
db *sql.DB
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewReader creates a new SQLite reader
|
||||
func NewReader(options *readers.ReaderOptions) *Reader {
|
||||
return &Reader{
|
||||
options: options,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadDatabase reads the entire database schema from SQLite
|
||||
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
||||
// Validate file path or connection string
|
||||
dbPath := r.options.FilePath
|
||||
if dbPath == "" && r.options.ConnectionString != "" {
|
||||
dbPath = r.options.ConnectionString
|
||||
}
|
||||
if dbPath == "" {
|
||||
return nil, fmt.Errorf("file path or connection string is required")
|
||||
}
|
||||
|
||||
// Connect to the database
|
||||
if err := r.connect(dbPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
defer r.close()
|
||||
|
||||
// Get database name from file path
|
||||
dbName := filepath.Base(dbPath)
|
||||
if dbName == "" {
|
||||
dbName = "sqlite"
|
||||
}
|
||||
|
||||
// Initialize database model
|
||||
db := models.InitDatabase(dbName)
|
||||
db.DatabaseType = models.SqlLiteDatabaseType
|
||||
db.SourceFormat = "sqlite"
|
||||
|
||||
// Get SQLite version
|
||||
var version string
|
||||
err := r.db.QueryRowContext(r.ctx, "SELECT sqlite_version()").Scan(&version)
|
||||
if err == nil {
|
||||
db.DatabaseVersion = version
|
||||
}
|
||||
|
||||
// SQLite doesn't have schemas, so we create a single "main" schema
|
||||
schema := models.InitSchema("main")
|
||||
schema.RefDatabase = db
|
||||
|
||||
// Query tables
|
||||
tables, err := r.queryTables()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
schema.Tables = tables
|
||||
|
||||
// Query views
|
||||
views, err := r.queryViews()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query views: %w", err)
|
||||
}
|
||||
schema.Views = views
|
||||
|
||||
// Query columns for tables and views
|
||||
for _, table := range schema.Tables {
|
||||
columns, err := r.queryColumns(table.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns for table %s: %w", table.Name, err)
|
||||
}
|
||||
table.Columns = columns
|
||||
table.RefSchema = schema
|
||||
|
||||
// Query primary key
|
||||
pk, err := r.queryPrimaryKey(table.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query primary key for table %s: %w", table.Name, err)
|
||||
}
|
||||
if pk != nil {
|
||||
table.Constraints[pk.Name] = pk
|
||||
// Mark columns as primary key and not null
|
||||
for _, colName := range pk.Columns {
|
||||
if col, exists := table.Columns[colName]; exists {
|
||||
col.IsPrimaryKey = true
|
||||
col.NotNull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query foreign keys
|
||||
foreignKeys, err := r.queryForeignKeys(table.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query foreign keys for table %s: %w", table.Name, err)
|
||||
}
|
||||
for _, fk := range foreignKeys {
|
||||
table.Constraints[fk.Name] = fk
|
||||
// Derive relationship from foreign key
|
||||
r.deriveRelationship(table, fk)
|
||||
}
|
||||
|
||||
// Query indexes
|
||||
indexes, err := r.queryIndexes(table.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query indexes for table %s: %w", table.Name, err)
|
||||
}
|
||||
for _, idx := range indexes {
|
||||
table.Indexes[idx.Name] = idx
|
||||
}
|
||||
}
|
||||
|
||||
// Query columns for views
|
||||
for _, view := range schema.Views {
|
||||
columns, err := r.queryColumns(view.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns for view %s: %w", view.Name, err)
|
||||
}
|
||||
view.Columns = columns
|
||||
view.RefSchema = schema
|
||||
}
|
||||
|
||||
// Add schema to database
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// ReadSchema reads a single schema (returns the main schema from the database)
|
||||
func (r *Reader) ReadSchema() (*models.Schema, error) {
|
||||
db, err := r.ReadDatabase()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(db.Schemas) == 0 {
|
||||
return nil, fmt.Errorf("no schemas found in database")
|
||||
}
|
||||
return db.Schemas[0], nil
|
||||
}
|
||||
|
||||
// ReadTable reads a single table (returns the first table from the schema)
|
||||
func (r *Reader) ReadTable() (*models.Table, error) {
|
||||
schema, err := r.ReadSchema()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(schema.Tables) == 0 {
|
||||
return nil, fmt.Errorf("no tables found in schema")
|
||||
}
|
||||
return schema.Tables[0], nil
|
||||
}
|
||||
|
||||
// connect establishes a connection to the SQLite database
|
||||
func (r *Reader) connect(dbPath string) error {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.db = db
|
||||
return nil
|
||||
}
|
||||
|
||||
// close closes the database connection
|
||||
func (r *Reader) close() {
|
||||
if r.db != nil {
|
||||
r.db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// mapDataType maps SQLite data types to canonical types
|
||||
func (r *Reader) mapDataType(sqliteType string) string {
|
||||
// SQLite has a flexible type system, but we map common types
|
||||
typeMap := map[string]string{
|
||||
"INTEGER": "int",
|
||||
"INT": "int",
|
||||
"TINYINT": "int8",
|
||||
"SMALLINT": "int16",
|
||||
"MEDIUMINT": "int",
|
||||
"BIGINT": "int64",
|
||||
"UNSIGNED BIG INT": "uint64",
|
||||
"INT2": "int16",
|
||||
"INT8": "int64",
|
||||
"REAL": "float64",
|
||||
"DOUBLE": "float64",
|
||||
"DOUBLE PRECISION": "float64",
|
||||
"FLOAT": "float32",
|
||||
"NUMERIC": "decimal",
|
||||
"DECIMAL": "decimal",
|
||||
"BOOLEAN": "bool",
|
||||
"BOOL": "bool",
|
||||
"DATE": "date",
|
||||
"DATETIME": "timestamp",
|
||||
"TIMESTAMP": "timestamp",
|
||||
"TEXT": "string",
|
||||
"VARCHAR": "string",
|
||||
"CHAR": "string",
|
||||
"CHARACTER": "string",
|
||||
"VARYING CHARACTER": "string",
|
||||
"NCHAR": "string",
|
||||
"NVARCHAR": "string",
|
||||
"CLOB": "text",
|
||||
"BLOB": "bytea",
|
||||
}
|
||||
|
||||
// Try exact match first
|
||||
if mapped, exists := typeMap[sqliteType]; exists {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// Try case-insensitive match for common types
|
||||
sqliteTypeUpper := sqliteType
|
||||
if len(sqliteType) > 0 {
|
||||
// Extract base type (e.g., "VARCHAR(255)" -> "VARCHAR")
|
||||
for baseType := range typeMap {
|
||||
if len(sqliteTypeUpper) >= len(baseType) && sqliteTypeUpper[:len(baseType)] == baseType {
|
||||
return typeMap[baseType]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default to string for unknown types
|
||||
return "string"
|
||||
}
|
||||
|
||||
// deriveRelationship creates a relationship from a foreign key constraint
|
||||
func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint) {
|
||||
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
||||
|
||||
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
||||
relationship.FromTable = table.Name
|
||||
relationship.FromSchema = table.Schema
|
||||
relationship.ToTable = fk.ReferencedTable
|
||||
relationship.ToSchema = fk.ReferencedSchema
|
||||
relationship.ForeignKey = fk.Name
|
||||
|
||||
// Store constraint actions in properties
|
||||
if fk.OnDelete != "" {
|
||||
relationship.Properties["on_delete"] = fk.OnDelete
|
||||
}
|
||||
if fk.OnUpdate != "" {
|
||||
relationship.Properties["on_update"] = fk.OnUpdate
|
||||
}
|
||||
|
||||
table.Relationships[relationshipName] = relationship
|
||||
}
|
||||
334
pkg/readers/sqlite/reader_test.go
Normal file
334
pkg/readers/sqlite/reader_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
// setupTestDatabase creates a temporary SQLite database with test data
|
||||
func setupTestDatabase(t *testing.T) string {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Create test schema
|
||||
schema := `
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username VARCHAR(50) NOT NULL UNIQUE,
|
||||
email VARCHAR(100) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE posts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(200) NOT NULL,
|
||||
content TEXT,
|
||||
published BOOLEAN DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE comments (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
post_id INTEGER NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
comment TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (post_id) REFERENCES posts(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX idx_posts_user_id ON posts(user_id);
|
||||
CREATE INDEX idx_comments_post_id ON comments(post_id);
|
||||
CREATE UNIQUE INDEX idx_users_email ON users(email);
|
||||
|
||||
CREATE VIEW user_post_count AS
|
||||
SELECT u.id, u.username, COUNT(p.id) as post_count
|
||||
FROM users u
|
||||
LEFT JOIN posts p ON u.id = p.user_id
|
||||
GROUP BY u.id, u.username;
|
||||
`
|
||||
|
||||
_, err = db.Exec(schema)
|
||||
require.NoError(t, err)
|
||||
|
||||
return dbPath
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase(t *testing.T) {
|
||||
dbPath := setupTestDatabase(t)
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
FilePath: dbPath,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
db, err := reader.ReadDatabase()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
|
||||
// Check database metadata
|
||||
assert.Equal(t, "test.db", db.Name)
|
||||
assert.Equal(t, models.SqlLiteDatabaseType, db.DatabaseType)
|
||||
assert.Equal(t, "sqlite", db.SourceFormat)
|
||||
assert.NotEmpty(t, db.DatabaseVersion)
|
||||
|
||||
// Check schemas (SQLite should have a single "main" schema)
|
||||
require.Len(t, db.Schemas, 1)
|
||||
schema := db.Schemas[0]
|
||||
assert.Equal(t, "main", schema.Name)
|
||||
|
||||
// Check tables
|
||||
assert.Len(t, schema.Tables, 3)
|
||||
tableNames := make([]string, len(schema.Tables))
|
||||
for i, table := range schema.Tables {
|
||||
tableNames[i] = table.Name
|
||||
}
|
||||
assert.Contains(t, tableNames, "users")
|
||||
assert.Contains(t, tableNames, "posts")
|
||||
assert.Contains(t, tableNames, "comments")
|
||||
|
||||
// Check views
|
||||
assert.Len(t, schema.Views, 1)
|
||||
assert.Equal(t, "user_post_count", schema.Views[0].Name)
|
||||
assert.NotEmpty(t, schema.Views[0].Definition)
|
||||
}
|
||||
|
||||
func TestReader_ReadTable_Users(t *testing.T) {
|
||||
dbPath := setupTestDatabase(t)
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
FilePath: dbPath,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
db, err := reader.ReadDatabase()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
|
||||
// Find users table
|
||||
var usersTable *models.Table
|
||||
for _, table := range db.Schemas[0].Tables {
|
||||
if table.Name == "users" {
|
||||
usersTable = table
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, usersTable)
|
||||
assert.Equal(t, "users", usersTable.Name)
|
||||
assert.Equal(t, "main", usersTable.Schema)
|
||||
|
||||
// Check columns
|
||||
assert.Len(t, usersTable.Columns, 4)
|
||||
|
||||
// Check id column
|
||||
idCol, exists := usersTable.Columns["id"]
|
||||
require.True(t, exists)
|
||||
assert.Equal(t, "int", idCol.Type)
|
||||
assert.True(t, idCol.IsPrimaryKey)
|
||||
assert.True(t, idCol.AutoIncrement)
|
||||
assert.True(t, idCol.NotNull)
|
||||
|
||||
// Check username column
|
||||
usernameCol, exists := usersTable.Columns["username"]
|
||||
require.True(t, exists)
|
||||
assert.Equal(t, "string", usernameCol.Type)
|
||||
assert.True(t, usernameCol.NotNull)
|
||||
assert.False(t, usernameCol.IsPrimaryKey)
|
||||
|
||||
// Check email column
|
||||
emailCol, exists := usersTable.Columns["email"]
|
||||
require.True(t, exists)
|
||||
assert.Equal(t, "string", emailCol.Type)
|
||||
assert.True(t, emailCol.NotNull)
|
||||
|
||||
// Check primary key constraint
|
||||
assert.Len(t, usersTable.Constraints, 1)
|
||||
pkConstraint, exists := usersTable.Constraints["users_pkey"]
|
||||
require.True(t, exists)
|
||||
assert.Equal(t, models.PrimaryKeyConstraint, pkConstraint.Type)
|
||||
assert.Equal(t, []string{"id"}, pkConstraint.Columns)
|
||||
|
||||
// Check indexes (should have unique index on email and username)
|
||||
assert.GreaterOrEqual(t, len(usersTable.Indexes), 1)
|
||||
}
|
||||
|
||||
func TestReader_ReadTable_Posts(t *testing.T) {
|
||||
dbPath := setupTestDatabase(t)
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
FilePath: dbPath,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
db, err := reader.ReadDatabase()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
|
||||
// Find posts table
|
||||
var postsTable *models.Table
|
||||
for _, table := range db.Schemas[0].Tables {
|
||||
if table.Name == "posts" {
|
||||
postsTable = table
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, postsTable)
|
||||
|
||||
// Check columns
|
||||
assert.Len(t, postsTable.Columns, 6)
|
||||
|
||||
// Check foreign key constraint
|
||||
hasForeignKey := false
|
||||
for _, constraint := range postsTable.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint {
|
||||
hasForeignKey = true
|
||||
assert.Equal(t, "users", constraint.ReferencedTable)
|
||||
assert.Equal(t, "CASCADE", constraint.OnDelete)
|
||||
}
|
||||
}
|
||||
assert.True(t, hasForeignKey, "Posts table should have a foreign key constraint")
|
||||
|
||||
// Check relationships
|
||||
assert.GreaterOrEqual(t, len(postsTable.Relationships), 1)
|
||||
|
||||
// Check indexes
|
||||
hasUserIdIndex := false
|
||||
for _, index := range postsTable.Indexes {
|
||||
if index.Name == "idx_posts_user_id" {
|
||||
hasUserIdIndex = true
|
||||
assert.Contains(t, index.Columns, "user_id")
|
||||
}
|
||||
}
|
||||
assert.True(t, hasUserIdIndex, "Posts table should have idx_posts_user_id index")
|
||||
}
|
||||
|
||||
func TestReader_ReadTable_Comments(t *testing.T) {
|
||||
dbPath := setupTestDatabase(t)
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
FilePath: dbPath,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
db, err := reader.ReadDatabase()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
|
||||
// Find comments table
|
||||
var commentsTable *models.Table
|
||||
for _, table := range db.Schemas[0].Tables {
|
||||
if table.Name == "comments" {
|
||||
commentsTable = table
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, commentsTable)
|
||||
|
||||
// Check foreign key constraints (should have 2)
|
||||
fkCount := 0
|
||||
for _, constraint := range commentsTable.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint {
|
||||
fkCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, fkCount, "Comments table should have 2 foreign key constraints")
|
||||
}
|
||||
|
||||
func TestReader_ReadSchema(t *testing.T) {
|
||||
dbPath := setupTestDatabase(t)
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
FilePath: dbPath,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
schema, err := reader.ReadSchema()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, schema)
|
||||
assert.Equal(t, "main", schema.Name)
|
||||
assert.Len(t, schema.Tables, 3)
|
||||
assert.Len(t, schema.Views, 1)
|
||||
}
|
||||
|
||||
func TestReader_ReadTable(t *testing.T) {
|
||||
dbPath := setupTestDatabase(t)
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
FilePath: dbPath,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
table, err := reader.ReadTable()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, table)
|
||||
assert.NotEmpty(t, table.Name)
|
||||
assert.NotEmpty(t, table.Columns)
|
||||
}
|
||||
|
||||
func TestReader_ConnectionString(t *testing.T) {
|
||||
dbPath := setupTestDatabase(t)
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
ConnectionString: dbPath,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
db, err := reader.ReadDatabase()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
assert.Len(t, db.Schemas, 1)
|
||||
}
|
||||
|
||||
func TestReader_InvalidPath(t *testing.T) {
|
||||
options := &readers.ReaderOptions{
|
||||
FilePath: "/nonexistent/path/to/database.db",
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
_, err := reader.ReadDatabase()
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReader_MissingPath(t *testing.T) {
|
||||
options := &readers.ReaderOptions{}
|
||||
|
||||
reader := NewReader(options)
|
||||
_, err := reader.ReadDatabase()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "file path or connection string is required")
|
||||
}
|
||||
36
pkg/reflectutil/doc.go
Normal file
36
pkg/reflectutil/doc.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Package reflectutil provides reflection utilities for analyzing Go code structures.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The reflectutil package offers helper functions for working with Go's reflection
|
||||
// capabilities, particularly for parsing Go struct definitions and extracting type
|
||||
// information. This is used by readers that parse ORM model files.
|
||||
//
|
||||
// # Features
|
||||
//
|
||||
// - Struct tag parsing and extraction
|
||||
// - Type information analysis
|
||||
// - Field metadata extraction
|
||||
// - ORM tag interpretation (GORM, Bun, etc.)
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// This package is primarily used internally by readers like GORM and Bun to parse
|
||||
// Go struct definitions and convert them to database schema models.
|
||||
//
|
||||
// // Example: Parse struct tags
|
||||
// tags := reflectutil.ParseStructTags(field)
|
||||
// columnName := tags.Get("db")
|
||||
//
|
||||
// # Supported ORM Tags
|
||||
//
|
||||
// The package understands tag conventions from:
|
||||
// - GORM (gorm tag)
|
||||
// - Bun (bun tag)
|
||||
// - Standard database/sql (db tag)
|
||||
//
|
||||
// # Purpose
|
||||
//
|
||||
// This package enables RelSpec to read existing ORM models and convert them to
|
||||
// a unified schema representation for transformation to other formats.
|
||||
package reflectutil
|
||||
490
pkg/reflectutil/helpers_test.go
Normal file
490
pkg/reflectutil/helpers_test.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package reflectutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
Active bool
|
||||
Nested *nestedStruct
|
||||
Private string
|
||||
}
|
||||
|
||||
type nestedStruct struct {
|
||||
Value string
|
||||
Count int
|
||||
}
|
||||
|
||||
func TestDeref(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantValid bool
|
||||
wantKind reflect.Kind
|
||||
}{
|
||||
{
|
||||
name: "non-pointer int",
|
||||
input: 42,
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "single pointer",
|
||||
input: ptrInt(42),
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "double pointer",
|
||||
input: ptrPtr(ptrInt(42)),
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "nil pointer",
|
||||
input: (*int)(nil),
|
||||
wantValid: false,
|
||||
wantKind: reflect.Ptr,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
input: "test",
|
||||
wantValid: true,
|
||||
wantKind: reflect.String,
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
input: testStruct{Name: "test"},
|
||||
wantValid: true,
|
||||
wantKind: reflect.Struct,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := reflect.ValueOf(tt.input)
|
||||
got, valid := Deref(v)
|
||||
|
||||
if valid != tt.wantValid {
|
||||
t.Errorf("Deref() valid = %v, want %v", valid, tt.wantValid)
|
||||
}
|
||||
|
||||
if got.Kind() != tt.wantKind {
|
||||
t.Errorf("Deref() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDerefInterface(t *testing.T) {
|
||||
i := 42
|
||||
pi := &i
|
||||
ppi := &pi
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantKind reflect.Kind
|
||||
}{
|
||||
{"int", 42, reflect.Int},
|
||||
{"pointer to int", &i, reflect.Int},
|
||||
{"double pointer to int", ppi, reflect.Int},
|
||||
{"string", "test", reflect.String},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DerefInterface(tt.input)
|
||||
if got.Kind() != tt.wantKind {
|
||||
t.Errorf("DerefInterface() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFieldValue(t *testing.T) {
|
||||
ts := testStruct{
|
||||
Name: "John",
|
||||
Age: 30,
|
||||
Active: true,
|
||||
Nested: &nestedStruct{Value: "nested", Count: 5},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
item interface{}
|
||||
field string
|
||||
want interface{}
|
||||
}{
|
||||
{"struct field Name", ts, "Name", "John"},
|
||||
{"struct field Age", ts, "Age", 30},
|
||||
{"struct field Active", ts, "Active", true},
|
||||
{"struct non-existent field", ts, "NonExistent", nil},
|
||||
{"pointer to struct", &ts, "Name", "John"},
|
||||
{"map string key", map[string]string{"key": "value"}, "key", "value"},
|
||||
{"map int key", map[string]int{"count": 42}, "count", 42},
|
||||
{"map non-existent key", map[string]string{"key": "value"}, "missing", nil},
|
||||
{"nil pointer", (*testStruct)(nil), "Name", nil},
|
||||
{"non-struct non-map", 42, "field", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetFieldValue(tt.item, tt.field)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetFieldValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSliceOrArray(t *testing.T) {
|
||||
arr := [3]int{1, 2, 3}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want bool
|
||||
}{
|
||||
{"slice", []int{1, 2, 3}, true},
|
||||
{"array", arr, true},
|
||||
{"pointer to slice", &[]int{1, 2, 3}, true},
|
||||
{"string", "test", false},
|
||||
{"int", 42, false},
|
||||
{"map", map[string]int{}, false},
|
||||
{"nil slice", ([]int)(nil), true}, // nil slice is still Kind==Slice
|
||||
{"nil pointer", (*[]int)(nil), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsSliceOrArray(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsSliceOrArray() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want bool
|
||||
}{
|
||||
{"map[string]int", map[string]int{"a": 1}, true},
|
||||
{"map[int]string", map[int]string{1: "a"}, true},
|
||||
{"pointer to map", &map[string]int{"a": 1}, true},
|
||||
{"slice", []int{1, 2, 3}, false},
|
||||
{"string", "test", false},
|
||||
{"int", 42, false},
|
||||
{"nil map", (map[string]int)(nil), true}, // nil map is still Kind==Map
|
||||
{"nil pointer", (*map[string]int)(nil), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsMap(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsMap() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceLen(t *testing.T) {
|
||||
arr := [3]int{1, 2, 3}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int
|
||||
}{
|
||||
{"slice length 3", []int{1, 2, 3}, 3},
|
||||
{"empty slice", []int{}, 0},
|
||||
{"array length 3", arr, 3},
|
||||
{"pointer to slice", &[]int{1, 2, 3}, 3},
|
||||
{"not a slice", "test", 0},
|
||||
{"int", 42, 0},
|
||||
{"nil slice", ([]int)(nil), 0},
|
||||
{"nil pointer", (*[]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceLen(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("SliceLen() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapLen(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int
|
||||
}{
|
||||
{"map length 2", map[string]int{"a": 1, "b": 2}, 2},
|
||||
{"empty map", map[string]int{}, 0},
|
||||
{"pointer to map", &map[string]int{"a": 1}, 1},
|
||||
{"not a map", []int{1, 2, 3}, 0},
|
||||
{"string", "test", 0},
|
||||
{"nil map", (map[string]int)(nil), 0},
|
||||
{"nil pointer", (*map[string]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapLen(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("MapLen() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceToInterfaces(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want []interface{}
|
||||
}{
|
||||
{"int slice", []int{1, 2, 3}, []interface{}{1, 2, 3}},
|
||||
{"string slice", []string{"a", "b"}, []interface{}{"a", "b"}},
|
||||
{"empty slice", []int{}, []interface{}{}},
|
||||
{"pointer to slice", &[]int{1, 2}, []interface{}{1, 2}},
|
||||
{"not a slice", "test", []interface{}{}},
|
||||
{"nil slice", ([]int)(nil), []interface{}{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceToInterfaces(tt.input)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SliceToInterfaces() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want []interface{}
|
||||
}{
|
||||
{"map with keys", map[string]int{"a": 1, "b": 2}, []interface{}{"a", "b"}},
|
||||
{"empty map", map[string]int{}, []interface{}{}},
|
||||
{"not a map", []int{1, 2, 3}, []interface{}{}},
|
||||
{"nil map", (map[string]int)(nil), []interface{}{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapKeys(tt.input)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("MapKeys() length = %v, want %v", len(got), len(tt.want))
|
||||
}
|
||||
// For maps, order is not guaranteed, so just check length
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int // length of values
|
||||
}{
|
||||
{"map with values", map[string]int{"a": 1, "b": 2}, 2},
|
||||
{"empty map", map[string]int{}, 0},
|
||||
{"not a map", []int{1, 2, 3}, 0},
|
||||
{"nil map", (map[string]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapValues(tt.input)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("MapValues() length = %v, want %v", len(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapGet(t *testing.T) {
|
||||
m := map[string]int{"a": 1, "b": 2}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
key interface{}
|
||||
want interface{}
|
||||
}{
|
||||
{"existing key", m, "a", 1},
|
||||
{"existing key b", m, "b", 2},
|
||||
{"non-existing key", m, "c", nil},
|
||||
{"pointer to map", &m, "a", 1},
|
||||
{"not a map", []int{1, 2}, 0, nil},
|
||||
{"nil map", (map[string]int)(nil), "a", nil},
|
||||
{"nil pointer", (*map[string]int)(nil), "a", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapGet(tt.input, tt.key)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("MapGet() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceIndex(t *testing.T) {
|
||||
s := []int{10, 20, 30}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
slice interface{}
|
||||
index int
|
||||
want interface{}
|
||||
}{
|
||||
{"index 0", s, 0, 10},
|
||||
{"index 1", s, 1, 20},
|
||||
{"index 2", s, 2, 30},
|
||||
{"negative index", s, -1, nil},
|
||||
{"out of bounds", s, 5, nil},
|
||||
{"pointer to slice", &s, 1, 20},
|
||||
{"not a slice", "test", 0, nil},
|
||||
{"nil slice", ([]int)(nil), 0, nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceIndex(tt.slice, tt.index)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SliceIndex() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a interface{}
|
||||
b interface{}
|
||||
want int
|
||||
}{
|
||||
{"both nil", nil, nil, 0},
|
||||
{"a nil", nil, 5, -1},
|
||||
{"b nil", 5, nil, 1},
|
||||
{"equal strings", "abc", "abc", 0},
|
||||
{"a less than b strings", "abc", "xyz", -1},
|
||||
{"a greater than b strings", "xyz", "abc", 1},
|
||||
{"equal ints", 5, 5, 0},
|
||||
{"a less than b ints", 3, 7, -1},
|
||||
{"a greater than b ints", 10, 5, 1},
|
||||
{"equal floats", 3.14, 3.14, 0},
|
||||
{"a less than b floats", 2.5, 5.5, -1},
|
||||
{"a greater than b floats", 10.5, 5.5, 1},
|
||||
{"equal uints", uint(5), uint(5), 0},
|
||||
{"different types", "abc", 123, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := CompareValues(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("CompareValues(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNestedValue(t *testing.T) {
|
||||
nested := map[string]interface{}{
|
||||
"level1": map[string]interface{}{
|
||||
"level2": map[string]interface{}{
|
||||
"value": "deep",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ts := testStruct{
|
||||
Name: "John",
|
||||
Nested: &nestedStruct{
|
||||
Value: "nested value",
|
||||
Count: 42,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
path string
|
||||
want interface{}
|
||||
}{
|
||||
{"empty path", nested, "", nested},
|
||||
{"single level map", nested, "level1", nested["level1"]},
|
||||
{"nested map", nested, "level1.level2", map[string]interface{}{"value": "deep"}},
|
||||
{"deep nested map", nested, "level1.level2.value", "deep"},
|
||||
{"struct field", ts, "Name", "John"},
|
||||
{"nested struct field", ts, "Nested", ts.Nested},
|
||||
{"non-existent path", nested, "missing.path", nil},
|
||||
{"nil input", nil, "path", nil},
|
||||
{"partial missing path", nested, "level1.missing", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetNestedValue(tt.input, tt.path)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetNestedValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a interface{}
|
||||
b interface{}
|
||||
want bool
|
||||
}{
|
||||
{"equal ints", 42, 42, true},
|
||||
{"different ints", 42, 43, false},
|
||||
{"equal strings", "test", "test", true},
|
||||
{"different strings", "test", "other", false},
|
||||
{"equal slices", []int{1, 2, 3}, []int{1, 2, 3}, true},
|
||||
{"different slices", []int{1, 2, 3}, []int{1, 2, 4}, false},
|
||||
{"equal maps", map[string]int{"a": 1}, map[string]int{"a": 1}, true},
|
||||
{"different maps", map[string]int{"a": 1}, map[string]int{"a": 2}, false},
|
||||
{"both nil", nil, nil, true},
|
||||
{"one nil", nil, 42, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DeepEqual(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("DeepEqual(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func ptrInt(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
func ptrPtr(p *int) **int {
|
||||
return &p
|
||||
}
|
||||
34
pkg/transform/doc.go
Normal file
34
pkg/transform/doc.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Package transform provides validation and transformation utilities for database models.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The transform package contains a Transformer type that provides methods for validating
|
||||
// and normalizing database schemas. It ensures schema correctness and consistency across
|
||||
// different format conversions.
|
||||
//
|
||||
// # Features
|
||||
//
|
||||
// - Database validation (structure and naming conventions)
|
||||
// - Schema validation (completeness and integrity)
|
||||
// - Table validation (column definitions and constraints)
|
||||
// - Data type normalization
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// transformer := transform.NewTransformer()
|
||||
// err := transformer.ValidateDatabase(db)
|
||||
// if err != nil {
|
||||
// log.Fatal("Invalid database schema:", err)
|
||||
// }
|
||||
//
|
||||
// # Validation Scope
|
||||
//
|
||||
// The transformer validates:
|
||||
// - Required fields presence
|
||||
// - Naming convention adherence
|
||||
// - Data type compatibility
|
||||
// - Constraint consistency
|
||||
// - Relationship integrity
|
||||
//
|
||||
// Note: Some validation methods are currently stubs and will be implemented as needed.
|
||||
package transform
|
||||
57
pkg/ui/doc.go
Normal file
57
pkg/ui/doc.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Package ui provides an interactive terminal user interface (TUI) for editing database schemas.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The ui package implements a full-featured terminal-based schema editor using tview,
|
||||
// allowing users to visually create, modify, and manage database schemas without writing
|
||||
// code or SQL.
|
||||
//
|
||||
// # Features
|
||||
//
|
||||
// The schema editor supports:
|
||||
// - Database management: Edit name, description, and properties
|
||||
// - Schema management: Create, edit, delete schemas
|
||||
// - Table management: Create, edit, delete tables
|
||||
// - Column management: Add, modify, delete columns with full property support
|
||||
// - Relationship management: Define and edit table relationships
|
||||
// - Domain management: Organize tables into logical domains
|
||||
// - Import & merge: Combine schemas from multiple sources
|
||||
// - Save: Export to any supported format
|
||||
//
|
||||
// # Architecture
|
||||
//
|
||||
// The package is organized into several components:
|
||||
// - editor.go: Main editor and application lifecycle
|
||||
// - *_screens.go: UI screens for each entity type
|
||||
// - *_dataops.go: Business logic and data operations
|
||||
// - dialogs.go: Reusable dialog components
|
||||
// - load_save_screens.go: File I/O and format selection
|
||||
// - main_menu.go: Primary navigation menu
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// editor := ui.NewSchemaEditor(database)
|
||||
// if err := editor.Run(); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// Or with pre-configured load/save options:
|
||||
//
|
||||
// editor := ui.NewSchemaEditorWithConfigs(database, loadConfig, saveConfig)
|
||||
// if err := editor.Run(); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// # Navigation
|
||||
//
|
||||
// - Arrow keys: Navigate between items
|
||||
// - Enter: Select/edit item
|
||||
// - Tab/Shift+Tab: Navigate between buttons
|
||||
// - Escape: Go back/cancel
|
||||
// - Letter shortcuts: Quick actions (e.g., 'n' for new, 'e' for edit, 'd' for delete)
|
||||
//
|
||||
// # Integration
|
||||
//
|
||||
// The editor integrates with all readers and writers, supporting load/save operations
|
||||
// for any format supported by RelSpec (DBML, PostgreSQL, GORM, Prisma, etc.).
|
||||
package ui
|
||||
115
pkg/ui/relation_dataops.go
Normal file
115
pkg/ui/relation_dataops.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package ui
|
||||
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
|
||||
// Relationship data operations - business logic for relationship management
|
||||
|
||||
// CreateRelationship creates a new relationship and adds it to a table
|
||||
func (se *SchemaEditor) CreateRelationship(schemaIndex, tableIndex int, rel *models.Relationship) *models.Relationship {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
if tableIndex < 0 || tableIndex >= len(schema.Tables) {
|
||||
return nil
|
||||
}
|
||||
|
||||
table := schema.Tables[tableIndex]
|
||||
if table.Relationships == nil {
|
||||
table.Relationships = make(map[string]*models.Relationship)
|
||||
}
|
||||
|
||||
table.Relationships[rel.Name] = rel
|
||||
table.UpdateDate()
|
||||
return rel
|
||||
}
|
||||
|
||||
// UpdateRelationship updates an existing relationship
|
||||
func (se *SchemaEditor) UpdateRelationship(schemaIndex, tableIndex int, oldName string, rel *models.Relationship) bool {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return false
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
if tableIndex < 0 || tableIndex >= len(schema.Tables) {
|
||||
return false
|
||||
}
|
||||
|
||||
table := schema.Tables[tableIndex]
|
||||
if table.Relationships == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Delete old entry if name changed
|
||||
if oldName != rel.Name {
|
||||
delete(table.Relationships, oldName)
|
||||
}
|
||||
|
||||
table.Relationships[rel.Name] = rel
|
||||
table.UpdateDate()
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteRelationship removes a relationship from a table
|
||||
func (se *SchemaEditor) DeleteRelationship(schemaIndex, tableIndex int, relName string) bool {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return false
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
if tableIndex < 0 || tableIndex >= len(schema.Tables) {
|
||||
return false
|
||||
}
|
||||
|
||||
table := schema.Tables[tableIndex]
|
||||
if table.Relationships == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(table.Relationships, relName)
|
||||
table.UpdateDate()
|
||||
return true
|
||||
}
|
||||
|
||||
// GetRelationship returns a relationship by name
|
||||
func (se *SchemaEditor) GetRelationship(schemaIndex, tableIndex int, relName string) *models.Relationship {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
if tableIndex < 0 || tableIndex >= len(schema.Tables) {
|
||||
return nil
|
||||
}
|
||||
|
||||
table := schema.Tables[tableIndex]
|
||||
if table.Relationships == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return table.Relationships[relName]
|
||||
}
|
||||
|
||||
// GetRelationshipNames returns all relationship names for a table
|
||||
func (se *SchemaEditor) GetRelationshipNames(schemaIndex, tableIndex int) []string {
|
||||
if schemaIndex < 0 || schemaIndex >= len(se.db.Schemas) {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema := se.db.Schemas[schemaIndex]
|
||||
if tableIndex < 0 || tableIndex >= len(schema.Tables) {
|
||||
return nil
|
||||
}
|
||||
|
||||
table := schema.Tables[tableIndex]
|
||||
if table.Relationships == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(table.Relationships))
|
||||
for name := range table.Relationships {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
486
pkg/ui/relation_screens.go
Normal file
486
pkg/ui/relation_screens.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// showRelationshipList displays all relationships for a table
|
||||
func (se *SchemaEditor) showRelationshipList(schemaIndex, tableIndex int) {
|
||||
table := se.GetTable(schemaIndex, tableIndex)
|
||||
if table == nil {
|
||||
return
|
||||
}
|
||||
|
||||
flex := tview.NewFlex().SetDirection(tview.FlexRow)
|
||||
|
||||
// Title
|
||||
title := tview.NewTextView().
|
||||
SetText(fmt.Sprintf("[::b]Relationships for Table: %s", table.Name)).
|
||||
SetDynamicColors(true).
|
||||
SetTextAlign(tview.AlignCenter)
|
||||
|
||||
// Create relationships table
|
||||
relTable := tview.NewTable().SetBorders(true).SetSelectable(true, false).SetFixed(1, 0)
|
||||
|
||||
// Add header row
|
||||
headers := []string{"Name", "Type", "From Columns", "To Table", "To Columns", "Description"}
|
||||
headerWidths := []int{20, 15, 20, 20, 20}
|
||||
for i, header := range headers {
|
||||
padding := ""
|
||||
if i < len(headerWidths) {
|
||||
padding = strings.Repeat(" ", headerWidths[i]-len(header))
|
||||
}
|
||||
cell := tview.NewTableCell(header + padding).
|
||||
SetTextColor(tcell.ColorYellow).
|
||||
SetSelectable(false).
|
||||
SetAlign(tview.AlignLeft)
|
||||
relTable.SetCell(0, i, cell)
|
||||
}
|
||||
|
||||
// Get relationship names
|
||||
relNames := se.GetRelationshipNames(schemaIndex, tableIndex)
|
||||
for row, relName := range relNames {
|
||||
rel := table.Relationships[relName]
|
||||
|
||||
// Name
|
||||
nameStr := fmt.Sprintf("%-20s", rel.Name)
|
||||
nameCell := tview.NewTableCell(nameStr).SetSelectable(true)
|
||||
relTable.SetCell(row+1, 0, nameCell)
|
||||
|
||||
// Type
|
||||
typeStr := fmt.Sprintf("%-15s", string(rel.Type))
|
||||
typeCell := tview.NewTableCell(typeStr).SetSelectable(true)
|
||||
relTable.SetCell(row+1, 1, typeCell)
|
||||
|
||||
// From Columns
|
||||
fromColsStr := strings.Join(rel.FromColumns, ", ")
|
||||
fromColsStr = fmt.Sprintf("%-20s", fromColsStr)
|
||||
fromColsCell := tview.NewTableCell(fromColsStr).SetSelectable(true)
|
||||
relTable.SetCell(row+1, 2, fromColsCell)
|
||||
|
||||
// To Table
|
||||
toTableStr := rel.ToTable
|
||||
if rel.ToSchema != "" && rel.ToSchema != table.Schema {
|
||||
toTableStr = rel.ToSchema + "." + rel.ToTable
|
||||
}
|
||||
toTableStr = fmt.Sprintf("%-20s", toTableStr)
|
||||
toTableCell := tview.NewTableCell(toTableStr).SetSelectable(true)
|
||||
relTable.SetCell(row+1, 3, toTableCell)
|
||||
|
||||
// To Columns
|
||||
toColsStr := strings.Join(rel.ToColumns, ", ")
|
||||
toColsStr = fmt.Sprintf("%-20s", toColsStr)
|
||||
toColsCell := tview.NewTableCell(toColsStr).SetSelectable(true)
|
||||
relTable.SetCell(row+1, 4, toColsCell)
|
||||
|
||||
// Description
|
||||
descCell := tview.NewTableCell(rel.Description).SetSelectable(true)
|
||||
relTable.SetCell(row+1, 5, descCell)
|
||||
}
|
||||
|
||||
relTable.SetTitle(" Relationships ").SetBorder(true).SetTitleAlign(tview.AlignLeft)
|
||||
|
||||
// Action buttons
|
||||
btnFlex := tview.NewFlex()
|
||||
btnNew := tview.NewButton("New Relationship [n]").SetSelectedFunc(func() {
|
||||
se.showNewRelationshipDialog(schemaIndex, tableIndex)
|
||||
})
|
||||
btnEdit := tview.NewButton("Edit [e]").SetSelectedFunc(func() {
|
||||
row, _ := relTable.GetSelection()
|
||||
if row > 0 && row <= len(relNames) {
|
||||
relName := relNames[row-1]
|
||||
se.showEditRelationshipDialog(schemaIndex, tableIndex, relName)
|
||||
}
|
||||
})
|
||||
btnDelete := tview.NewButton("Delete [d]").SetSelectedFunc(func() {
|
||||
row, _ := relTable.GetSelection()
|
||||
if row > 0 && row <= len(relNames) {
|
||||
relName := relNames[row-1]
|
||||
se.showDeleteRelationshipConfirm(schemaIndex, tableIndex, relName)
|
||||
}
|
||||
})
|
||||
btnBack := tview.NewButton("Back [b]").SetSelectedFunc(func() {
|
||||
se.pages.RemovePage("relationships")
|
||||
se.pages.SwitchToPage("table-editor")
|
||||
})
|
||||
|
||||
// Set up button navigation
|
||||
btnNew.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(relTable)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnEdit)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnEdit.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnNew)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnDelete)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnDelete.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnEdit)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnBack)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnBack.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnDelete)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(relTable)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnFlex.AddItem(btnNew, 0, 1, true).
|
||||
AddItem(btnEdit, 0, 1, false).
|
||||
AddItem(btnDelete, 0, 1, false).
|
||||
AddItem(btnBack, 0, 1, false)
|
||||
|
||||
relTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("relationships")
|
||||
se.pages.SwitchToPage("table-editor")
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnNew)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyEnter {
|
||||
row, _ := relTable.GetSelection()
|
||||
if row > 0 && row <= len(relNames) {
|
||||
relName := relNames[row-1]
|
||||
se.showEditRelationshipDialog(schemaIndex, tableIndex, relName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'n' {
|
||||
se.showNewRelationshipDialog(schemaIndex, tableIndex)
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'e' {
|
||||
row, _ := relTable.GetSelection()
|
||||
if row > 0 && row <= len(relNames) {
|
||||
relName := relNames[row-1]
|
||||
se.showEditRelationshipDialog(schemaIndex, tableIndex, relName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'd' {
|
||||
row, _ := relTable.GetSelection()
|
||||
if row > 0 && row <= len(relNames) {
|
||||
relName := relNames[row-1]
|
||||
se.showDeleteRelationshipConfirm(schemaIndex, tableIndex, relName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'b' {
|
||||
se.pages.RemovePage("relationships")
|
||||
se.pages.SwitchToPage("table-editor")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
flex.AddItem(title, 1, 0, false).
|
||||
AddItem(relTable, 0, 1, true).
|
||||
AddItem(btnFlex, 1, 0, false)
|
||||
|
||||
se.pages.AddPage("relationships", flex, true, true)
|
||||
}
|
||||
|
||||
// showNewRelationshipDialog shows dialog to create a new relationship
|
||||
func (se *SchemaEditor) showNewRelationshipDialog(schemaIndex, tableIndex int) {
|
||||
table := se.GetTable(schemaIndex, tableIndex)
|
||||
if table == nil {
|
||||
return
|
||||
}
|
||||
|
||||
form := tview.NewForm()
|
||||
|
||||
// Collect all tables for dropdown
|
||||
var allTables []string
|
||||
var tableMap []struct{ schemaIdx, tableIdx int }
|
||||
for si, schema := range se.db.Schemas {
|
||||
for ti, t := range schema.Tables {
|
||||
tableName := t.Name
|
||||
if schema.Name != table.Schema {
|
||||
tableName = schema.Name + "." + t.Name
|
||||
}
|
||||
allTables = append(allTables, tableName)
|
||||
tableMap = append(tableMap, struct{ schemaIdx, tableIdx int }{si, ti})
|
||||
}
|
||||
}
|
||||
|
||||
relName := ""
|
||||
relType := models.OneToMany
|
||||
fromColumns := ""
|
||||
toColumns := ""
|
||||
description := ""
|
||||
selectedTableIdx := 0
|
||||
|
||||
form.AddInputField("Name", "", 40, nil, func(value string) {
|
||||
relName = value
|
||||
})
|
||||
|
||||
form.AddDropDown("Type", []string{
|
||||
string(models.OneToOne),
|
||||
string(models.OneToMany),
|
||||
string(models.ManyToMany),
|
||||
}, 1, func(option string, optionIndex int) {
|
||||
relType = models.RelationType(option)
|
||||
})
|
||||
|
||||
form.AddInputField("From Columns (comma-separated)", "", 40, nil, func(value string) {
|
||||
fromColumns = value
|
||||
})
|
||||
|
||||
form.AddDropDown("To Table", allTables, 0, func(option string, optionIndex int) {
|
||||
selectedTableIdx = optionIndex
|
||||
})
|
||||
|
||||
form.AddInputField("To Columns (comma-separated)", "", 40, nil, func(value string) {
|
||||
toColumns = value
|
||||
})
|
||||
|
||||
form.AddInputField("Description", "", 60, nil, func(value string) {
|
||||
description = value
|
||||
})
|
||||
|
||||
form.AddButton("Save", func() {
|
||||
if relName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse columns
|
||||
fromCols := strings.Split(fromColumns, ",")
|
||||
for i := range fromCols {
|
||||
fromCols[i] = strings.TrimSpace(fromCols[i])
|
||||
}
|
||||
|
||||
toCols := strings.Split(toColumns, ",")
|
||||
for i := range toCols {
|
||||
toCols[i] = strings.TrimSpace(toCols[i])
|
||||
}
|
||||
|
||||
// Get target table
|
||||
targetSchema := se.db.Schemas[tableMap[selectedTableIdx].schemaIdx]
|
||||
targetTable := targetSchema.Tables[tableMap[selectedTableIdx].tableIdx]
|
||||
|
||||
rel := models.InitRelationship(relName, relType)
|
||||
rel.FromTable = table.Name
|
||||
rel.FromSchema = table.Schema
|
||||
rel.FromColumns = fromCols
|
||||
rel.ToTable = targetTable.Name
|
||||
rel.ToSchema = targetTable.Schema
|
||||
rel.ToColumns = toCols
|
||||
rel.Description = description
|
||||
|
||||
se.CreateRelationship(schemaIndex, tableIndex, rel)
|
||||
|
||||
se.pages.RemovePage("new-relationship")
|
||||
se.pages.RemovePage("relationships")
|
||||
se.showRelationshipList(schemaIndex, tableIndex)
|
||||
})
|
||||
|
||||
form.AddButton("Back", func() {
|
||||
se.pages.RemovePage("new-relationship")
|
||||
})
|
||||
|
||||
form.SetBorder(true).SetTitle(" New Relationship ").SetTitleAlign(tview.AlignLeft)
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("new-relationship")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("new-relationship", form, true, true)
|
||||
}
|
||||
|
||||
// showEditRelationshipDialog shows dialog to edit a relationship
|
||||
func (se *SchemaEditor) showEditRelationshipDialog(schemaIndex, tableIndex int, relName string) {
|
||||
table := se.GetTable(schemaIndex, tableIndex)
|
||||
if table == nil {
|
||||
return
|
||||
}
|
||||
|
||||
rel := se.GetRelationship(schemaIndex, tableIndex, relName)
|
||||
if rel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
form := tview.NewForm()
|
||||
|
||||
// Collect all tables for dropdown
|
||||
var allTables []string
|
||||
var tableMap []struct{ schemaIdx, tableIdx int }
|
||||
selectedTableIdx := 0
|
||||
for si, schema := range se.db.Schemas {
|
||||
for ti, t := range schema.Tables {
|
||||
tableName := t.Name
|
||||
if schema.Name != table.Schema {
|
||||
tableName = schema.Name + "." + t.Name
|
||||
}
|
||||
allTables = append(allTables, tableName)
|
||||
tableMap = append(tableMap, struct{ schemaIdx, tableIdx int }{si, ti})
|
||||
|
||||
// Check if this is the current target table
|
||||
if t.Name == rel.ToTable && schema.Name == rel.ToSchema {
|
||||
selectedTableIdx = len(allTables) - 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newName := rel.Name
|
||||
relType := rel.Type
|
||||
fromColumns := strings.Join(rel.FromColumns, ", ")
|
||||
toColumns := strings.Join(rel.ToColumns, ", ")
|
||||
description := rel.Description
|
||||
|
||||
form.AddInputField("Name", rel.Name, 40, nil, func(value string) {
|
||||
newName = value
|
||||
})
|
||||
|
||||
// Find initial type index
|
||||
typeIdx := 1 // OneToMany default
|
||||
typeOptions := []string{
|
||||
string(models.OneToOne),
|
||||
string(models.OneToMany),
|
||||
string(models.ManyToMany),
|
||||
}
|
||||
for i, opt := range typeOptions {
|
||||
if opt == string(rel.Type) {
|
||||
typeIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
form.AddDropDown("Type", typeOptions, typeIdx, func(option string, optionIndex int) {
|
||||
relType = models.RelationType(option)
|
||||
})
|
||||
|
||||
form.AddInputField("From Columns (comma-separated)", fromColumns, 40, nil, func(value string) {
|
||||
fromColumns = value
|
||||
})
|
||||
|
||||
form.AddDropDown("To Table", allTables, selectedTableIdx, func(option string, optionIndex int) {
|
||||
selectedTableIdx = optionIndex
|
||||
})
|
||||
|
||||
form.AddInputField("To Columns (comma-separated)", toColumns, 40, nil, func(value string) {
|
||||
toColumns = value
|
||||
})
|
||||
|
||||
form.AddInputField("Description", rel.Description, 60, nil, func(value string) {
|
||||
description = value
|
||||
})
|
||||
|
||||
form.AddButton("Save", func() {
|
||||
if newName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse columns
|
||||
fromCols := strings.Split(fromColumns, ",")
|
||||
for i := range fromCols {
|
||||
fromCols[i] = strings.TrimSpace(fromCols[i])
|
||||
}
|
||||
|
||||
toCols := strings.Split(toColumns, ",")
|
||||
for i := range toCols {
|
||||
toCols[i] = strings.TrimSpace(toCols[i])
|
||||
}
|
||||
|
||||
// Get target table
|
||||
targetSchema := se.db.Schemas[tableMap[selectedTableIdx].schemaIdx]
|
||||
targetTable := targetSchema.Tables[tableMap[selectedTableIdx].tableIdx]
|
||||
|
||||
updatedRel := models.InitRelationship(newName, relType)
|
||||
updatedRel.FromTable = table.Name
|
||||
updatedRel.FromSchema = table.Schema
|
||||
updatedRel.FromColumns = fromCols
|
||||
updatedRel.ToTable = targetTable.Name
|
||||
updatedRel.ToSchema = targetTable.Schema
|
||||
updatedRel.ToColumns = toCols
|
||||
updatedRel.Description = description
|
||||
updatedRel.GUID = rel.GUID
|
||||
|
||||
se.UpdateRelationship(schemaIndex, tableIndex, relName, updatedRel)
|
||||
|
||||
se.pages.RemovePage("edit-relationship")
|
||||
se.pages.RemovePage("relationships")
|
||||
se.showRelationshipList(schemaIndex, tableIndex)
|
||||
})
|
||||
|
||||
form.AddButton("Back", func() {
|
||||
se.pages.RemovePage("edit-relationship")
|
||||
})
|
||||
|
||||
form.SetBorder(true).SetTitle(" Edit Relationship ").SetTitleAlign(tview.AlignLeft)
|
||||
form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("edit-relationship")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddPage("edit-relationship", form, true, true)
|
||||
}
|
||||
|
||||
// showDeleteRelationshipConfirm shows confirmation dialog for deleting a relationship
|
||||
func (se *SchemaEditor) showDeleteRelationshipConfirm(schemaIndex, tableIndex int, relName string) {
|
||||
modal := tview.NewModal().
|
||||
SetText(fmt.Sprintf("Delete relationship '%s'? This action cannot be undone.", relName)).
|
||||
AddButtons([]string{"Cancel", "Delete"}).
|
||||
SetDoneFunc(func(buttonIndex int, buttonLabel string) {
|
||||
if buttonLabel == "Delete" {
|
||||
se.DeleteRelationship(schemaIndex, tableIndex, relName)
|
||||
se.pages.RemovePage("delete-relationship-confirm")
|
||||
se.pages.RemovePage("relationships")
|
||||
se.showRelationshipList(schemaIndex, tableIndex)
|
||||
} else {
|
||||
se.pages.RemovePage("delete-relationship-confirm")
|
||||
}
|
||||
})
|
||||
|
||||
modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyEscape {
|
||||
se.pages.RemovePage("delete-relationship-confirm")
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
se.pages.AddAndSwitchToPage("delete-relationship-confirm", modal, true)
|
||||
}
|
||||
@@ -270,6 +270,9 @@ func (se *SchemaEditor) showTableEditor(schemaIndex, tableIndex int, table *mode
|
||||
se.showColumnEditor(schemaIndex, tableIndex, row-1, column)
|
||||
}
|
||||
})
|
||||
btnRelations := tview.NewButton("Relations [r]").SetSelectedFunc(func() {
|
||||
se.showRelationshipList(schemaIndex, tableIndex)
|
||||
})
|
||||
btnDelTable := tview.NewButton("Delete Table [d]").SetSelectedFunc(func() {
|
||||
se.showDeleteTableConfirm(schemaIndex, tableIndex)
|
||||
})
|
||||
@@ -308,6 +311,18 @@ func (se *SchemaEditor) showTableEditor(schemaIndex, tableIndex int, table *mode
|
||||
se.app.SetFocus(btnEditColumn)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnRelations)
|
||||
return nil
|
||||
}
|
||||
return event
|
||||
})
|
||||
|
||||
btnRelations.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnEditTable)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
se.app.SetFocus(btnDelTable)
|
||||
return nil
|
||||
@@ -317,7 +332,7 @@ func (se *SchemaEditor) showTableEditor(schemaIndex, tableIndex int, table *mode
|
||||
|
||||
btnDelTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
if event.Key() == tcell.KeyBacktab {
|
||||
se.app.SetFocus(btnEditTable)
|
||||
se.app.SetFocus(btnRelations)
|
||||
return nil
|
||||
}
|
||||
if event.Key() == tcell.KeyTab {
|
||||
@@ -342,6 +357,7 @@ func (se *SchemaEditor) showTableEditor(schemaIndex, tableIndex int, table *mode
|
||||
btnFlex.AddItem(btnNewCol, 0, 1, true).
|
||||
AddItem(btnEditColumn, 0, 1, false).
|
||||
AddItem(btnEditTable, 0, 1, false).
|
||||
AddItem(btnRelations, 0, 1, false).
|
||||
AddItem(btnDelTable, 0, 1, false).
|
||||
AddItem(btnBack, 0, 1, false)
|
||||
|
||||
@@ -373,6 +389,10 @@ func (se *SchemaEditor) showTableEditor(schemaIndex, tableIndex int, table *mode
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'r' {
|
||||
se.showRelationshipList(schemaIndex, tableIndex)
|
||||
return nil
|
||||
}
|
||||
if event.Rune() == 'b' {
|
||||
se.pages.RemovePage("table-editor")
|
||||
se.pages.SwitchToPage("schema-editor")
|
||||
|
||||
@@ -106,15 +106,11 @@ func (td *TemplateData) FinalizeImports() {
|
||||
}
|
||||
|
||||
// NewModelData creates a new ModelData from a models.Table
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
|
||||
tableName := table.Name
|
||||
if schema != "" {
|
||||
tableName = schema + "." + table.Name
|
||||
}
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, flattenSchema bool) *ModelData {
|
||||
tableName := writers.QualifiedTableName(schema, table.Name, flattenSchema)
|
||||
|
||||
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||
singularTable := Singularize(table.Name)
|
||||
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||
tablePart := SnakeCaseToPascalCase(table.Name)
|
||||
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
|
||||
@@ -62,6 +62,17 @@ func (tm *TypeMapper) isSimpleType(sqlType string) bool {
|
||||
return simpleTypes[sqlType]
|
||||
}
|
||||
|
||||
// isSerialType checks if a SQL type is a serial type (auto-incrementing)
|
||||
func (tm *TypeMapper) isSerialType(sqlType string) bool {
|
||||
baseType := tm.extractBaseType(sqlType)
|
||||
serialTypes := map[string]bool{
|
||||
"serial": true,
|
||||
"bigserial": true,
|
||||
"smallserial": true,
|
||||
}
|
||||
return serialTypes[baseType]
|
||||
}
|
||||
|
||||
// baseGoType returns the base Go type for a SQL type (not null, simple types only)
|
||||
func (tm *TypeMapper) baseGoType(sqlType string) string {
|
||||
typeMap := map[string]string{
|
||||
@@ -122,10 +133,10 @@ func (tm *TypeMapper) bunGoType(sqlType string) string {
|
||||
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
|
||||
// Date/Time types
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"date": tm.sqlTypesAlias + ".SqlDate",
|
||||
"time": tm.sqlTypesAlias + ".SqlTime",
|
||||
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
@@ -190,6 +201,11 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
||||
parts = append(parts, "pk")
|
||||
}
|
||||
|
||||
// Auto increment (for serial types or explicit auto_increment)
|
||||
if column.AutoIncrement || tm.isSerialType(column.Type) {
|
||||
parts = append(parts, "autoincrement")
|
||||
}
|
||||
|
||||
// Default value
|
||||
if column.Default != nil {
|
||||
// Sanitize default value to remove backticks
|
||||
@@ -251,7 +267,15 @@ func (tm *TypeMapper) BuildRelationshipTag(constraint *models.Constraint, relTyp
|
||||
if len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 {
|
||||
localCol := constraint.Columns[0]
|
||||
foreignCol := constraint.ReferencedColumns[0]
|
||||
parts = append(parts, fmt.Sprintf("join:%s=%s", localCol, foreignCol))
|
||||
|
||||
// For has-many relationships, swap the columns
|
||||
// has-one: join:fk_in_this_table=pk_in_other_table
|
||||
// has-many: join:pk_in_this_table=fk_in_other_table
|
||||
if relType == "has-many" {
|
||||
parts = append(parts, fmt.Sprintf("join:%s=%s", foreignCol, localCol))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("join:%s=%s", localCol, foreignCol))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, ",")
|
||||
|
||||
@@ -86,7 +86,7 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
// Collect all models
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
@@ -181,7 +181,7 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||
|
||||
// Create model data
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
@@ -318,8 +318,7 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
||||
|
||||
// getModelName generates the model name from schema and table name
|
||||
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||
singular := Singularize(tableName)
|
||||
tablePart := SnakeCaseToPascalCase(singular)
|
||||
tablePart := SnakeCaseToPascalCase(tableName)
|
||||
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
// Verify key elements are present
|
||||
expectations := []string{
|
||||
"package models",
|
||||
"type ModelPublicUser struct",
|
||||
"type ModelPublicUsers struct",
|
||||
"bun.BaseModel",
|
||||
"table:public.users",
|
||||
"alias:users",
|
||||
@@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
"resolvespec_common.SqlTime",
|
||||
"bun:\"id",
|
||||
"bun:\"email",
|
||||
"func (m ModelPublicUser) TableName() string",
|
||||
"func (m ModelPublicUsers) TableName() string",
|
||||
"return \"public.users\"",
|
||||
"func (m ModelPublicUser) GetID() int64",
|
||||
"func (m ModelPublicUsers) GetID() int64",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
@@ -90,8 +90,8 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify Bun-specific elements
|
||||
if !strings.Contains(generated, "bun:\"id,type:bigint,pk,") {
|
||||
t.Errorf("Missing Bun-style primary key tag")
|
||||
if !strings.Contains(generated, "bun:\"id,type:bigint,pk,autoincrement,") {
|
||||
t.Errorf("Missing Bun-style primary key tag with autoincrement")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,14 +308,20 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||
filepointerStr := string(filepointerContent)
|
||||
|
||||
// Should have two different has-many relationships with unique names
|
||||
hasManyExpectations := []string{
|
||||
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||
hasManyExpectations := []struct {
|
||||
fieldName string
|
||||
tag string
|
||||
}{
|
||||
{"RelRIDFilepointerRequestOrgAPIEvents", "join:id_filepointer=rid_filepointer_request"}, // Has many via rid_filepointer_request
|
||||
{"RelRIDFilepointerResponseOrgAPIEvents", "join:id_filepointer=rid_filepointer_response"}, // Has many via rid_filepointer_response
|
||||
}
|
||||
|
||||
for _, exp := range hasManyExpectations {
|
||||
if !strings.Contains(filepointerStr, exp) {
|
||||
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
|
||||
if !strings.Contains(filepointerStr, exp.fieldName) {
|
||||
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp.fieldName, filepointerStr)
|
||||
}
|
||||
if !strings.Contains(filepointerStr, exp.tag) {
|
||||
t.Errorf("Missing has-many relationship join tag: %s\nGenerated:\n%s", exp.tag, filepointerStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -561,8 +567,8 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||
{"bigint", false, "resolvespec_common.SqlInt64"},
|
||||
{"varchar", true, "resolvespec_common.SqlString"}, // Bun uses sql types even for NOT NULL strings
|
||||
{"varchar", false, "resolvespec_common.SqlString"},
|
||||
{"timestamp", true, "resolvespec_common.SqlTime"},
|
||||
{"timestamp", false, "resolvespec_common.SqlTime"},
|
||||
{"timestamp", true, "resolvespec_common.SqlTimeStamp"},
|
||||
{"timestamp", false, "resolvespec_common.SqlTimeStamp"},
|
||||
{"date", false, "resolvespec_common.SqlDate"},
|
||||
{"boolean", true, "bool"},
|
||||
{"boolean", false, "resolvespec_common.SqlBool"},
|
||||
@@ -618,6 +624,37 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||
},
|
||||
want: []string{"status,", "type:text,", "default:active,"},
|
||||
},
|
||||
{
|
||||
name: "auto increment with AutoIncrement flag",
|
||||
column: &models.Column{
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
},
|
||||
want: []string{"id,", "type:bigint,", "pk,", "autoincrement,"},
|
||||
},
|
||||
{
|
||||
name: "serial type (auto-increment)",
|
||||
column: &models.Column{
|
||||
Name: "id",
|
||||
Type: "serial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
want: []string{"id,", "type:serial,", "pk,", "autoincrement,"},
|
||||
},
|
||||
{
|
||||
name: "bigserial type (auto-increment)",
|
||||
column: &models.Column{
|
||||
Name: "id",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
67
pkg/writers/doc.go
Normal file
67
pkg/writers/doc.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Package writers provides interfaces and implementations for writing database schemas
|
||||
// to various output formats and destinations.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The writers package defines a common Writer interface that all format-specific writers
|
||||
// implement. This allows RelSpec to export database schemas to multiple formats including:
|
||||
// - SQL schema files (PostgreSQL, SQLite)
|
||||
// - Schema definition files (DBML, DCTX, DrawDB, GraphQL)
|
||||
// - ORM model files (GORM, Bun, Drizzle, Prisma, TypeORM)
|
||||
// - Data interchange formats (JSON, YAML)
|
||||
//
|
||||
// # Architecture
|
||||
//
|
||||
// Each writer implementation is located in its own subpackage (e.g., pkg/writers/dbml,
|
||||
// pkg/writers/pgsql) and implements the Writer interface, supporting three levels of
|
||||
// granularity:
|
||||
// - WriteDatabase() - Write complete database with all schemas
|
||||
// - WriteSchema() - Write single schema with all tables
|
||||
// - WriteTable() - Write single table with all columns and metadata
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// Writers are instantiated with WriterOptions containing destination-specific configuration:
|
||||
//
|
||||
// // Write to file
|
||||
// writer := dbml.NewWriter(&writers.WriterOptions{
|
||||
// OutputPath: "schema.dbml",
|
||||
// })
|
||||
// err := writer.WriteDatabase(db)
|
||||
//
|
||||
// // Write ORM models with package name
|
||||
// writer := gorm.NewWriter(&writers.WriterOptions{
|
||||
// OutputPath: "./models",
|
||||
// PackageName: "models",
|
||||
// })
|
||||
// err := writer.WriteDatabase(db)
|
||||
//
|
||||
// // Write with schema flattening for SQLite
|
||||
// writer := sqlite.NewWriter(&writers.WriterOptions{
|
||||
// OutputPath: "schema.sql",
|
||||
// FlattenSchema: true,
|
||||
// })
|
||||
// err := writer.WriteDatabase(db)
|
||||
//
|
||||
// # Schema Flattening
|
||||
//
|
||||
// The FlattenSchema option controls how schema-qualified table names are handled:
|
||||
// - false (default): Uses dot notation (schema.table)
|
||||
// - true: Joins with underscore (schema_table), useful for SQLite
|
||||
//
|
||||
// # Supported Formats
|
||||
//
|
||||
// - dbml: Database Markup Language files
|
||||
// - dctx: DCTX schema files
|
||||
// - drawdb: DrawDB JSON format
|
||||
// - graphql: GraphQL schema definition language
|
||||
// - json: JSON database schema
|
||||
// - yaml: YAML database schema
|
||||
// - gorm: Go GORM model structs
|
||||
// - bun: Go Bun model structs
|
||||
// - drizzle: TypeScript Drizzle ORM schemas
|
||||
// - prisma: Prisma schema language
|
||||
// - typeorm: TypeScript TypeORM entities
|
||||
// - pgsql: PostgreSQL SQL schema
|
||||
// - sqlite: SQLite SQL schema with automatic flattening
|
||||
package writers
|
||||
@@ -105,15 +105,11 @@ func (td *TemplateData) FinalizeImports() {
|
||||
}
|
||||
|
||||
// NewModelData creates a new ModelData from a models.Table
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
|
||||
tableName := table.Name
|
||||
if schema != "" {
|
||||
tableName = schema + "." + table.Name
|
||||
}
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper, flattenSchema bool) *ModelData {
|
||||
tableName := writers.QualifiedTableName(schema, table.Name, flattenSchema)
|
||||
|
||||
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||
singularTable := Singularize(table.Name)
|
||||
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||
tablePart := SnakeCaseToPascalCase(table.Name)
|
||||
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
|
||||
@@ -158,10 +158,10 @@ func (tm *TypeMapper) nullableGoType(sqlType string) string {
|
||||
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
|
||||
// Date/Time types
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"date": tm.sqlTypesAlias + ".SqlDate",
|
||||
"time": tm.sqlTypesAlias + ".SqlTime",
|
||||
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
|
||||
@@ -83,7 +83,7 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
// Collect all models
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
@@ -175,7 +175,7 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||
|
||||
// Create model data
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper, w.options.FlattenSchema)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
@@ -312,8 +312,7 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
||||
|
||||
// getModelName generates the model name from schema and table name
|
||||
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||
singular := Singularize(tableName)
|
||||
tablePart := SnakeCaseToPascalCase(singular)
|
||||
tablePart := SnakeCaseToPascalCase(tableName)
|
||||
|
||||
// Include schema name in model name
|
||||
var modelName string
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
// Verify key elements are present
|
||||
expectations := []string{
|
||||
"package models",
|
||||
"type ModelPublicUser struct",
|
||||
"type ModelPublicUsers struct",
|
||||
"ID",
|
||||
"int64",
|
||||
"Email",
|
||||
@@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
"time.Time",
|
||||
"gorm:\"column:id",
|
||||
"gorm:\"column:email",
|
||||
"func (m ModelPublicUser) TableName() string",
|
||||
"func (m ModelPublicUsers) TableName() string",
|
||||
"return \"public.users\"",
|
||||
"func (m ModelPublicUser) GetID() int64",
|
||||
"func (m ModelPublicUsers) GetID() int64",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
@@ -655,7 +655,7 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
||||
{"varchar", true, "string"},
|
||||
{"varchar", false, "sql_types.SqlString"},
|
||||
{"timestamp", true, "time.Time"},
|
||||
{"timestamp", false, "sql_types.SqlTime"},
|
||||
{"timestamp", false, "sql_types.SqlTimeStamp"},
|
||||
{"boolean", true, "bool"},
|
||||
{"boolean", false, "sql_types.SqlBool"},
|
||||
}
|
||||
|
||||
130
pkg/writers/mssql/README.md
Normal file
130
pkg/writers/mssql/README.md
Normal file
@@ -0,0 +1,130 @@
|
||||
# MSSQL Writer
|
||||
|
||||
Generates Microsoft SQL Server DDL (Data Definition Language) from database schema models.
|
||||
|
||||
## Features
|
||||
|
||||
- **DDL Generation**: Generates complete SQL scripts for creating MSSQL schema
|
||||
- **Schema Support**: Creates multiple schemas with proper naming
|
||||
- **Bracket Notation**: Uses [schema].[table] bracket notation for identifiers
|
||||
- **Identity Columns**: Generates IDENTITY(1,1) for auto-increment columns
|
||||
- **Constraints**: Generates primary keys, foreign keys, unique, and check constraints
|
||||
- **Indexes**: Creates indexes with unique support
|
||||
- **Extended Properties**: Uses sp_addextendedproperty for comments
|
||||
- **Direct Execution**: Can directly execute DDL on MSSQL database
|
||||
- **Schema Flattening**: Optional schema flattening for compatibility
|
||||
|
||||
## Features by Phase
|
||||
|
||||
1. **Phase 1**: Create schemas
|
||||
2. **Phase 2**: Create tables with columns, identity, and defaults
|
||||
3. **Phase 3**: Add primary key constraints
|
||||
4. **Phase 4**: Create indexes
|
||||
5. **Phase 5**: Add unique constraints
|
||||
6. **Phase 6**: Add check constraints
|
||||
7. **Phase 7**: Add foreign key constraints
|
||||
8. **Phase 8**: Add extended properties (comments)
|
||||
|
||||
## Type Mappings
|
||||
|
||||
| Canonical Type | MSSQL Type |
|
||||
|----------------|-----------|
|
||||
| int | INT |
|
||||
| int64 | BIGINT |
|
||||
| int16 | SMALLINT |
|
||||
| int8 | TINYINT |
|
||||
| bool | BIT |
|
||||
| float32 | REAL |
|
||||
| float64 | FLOAT |
|
||||
| decimal | NUMERIC |
|
||||
| string | NVARCHAR(255) |
|
||||
| text | NVARCHAR(MAX) |
|
||||
| timestamp | DATETIME2 |
|
||||
| timestamptz | DATETIMEOFFSET |
|
||||
| uuid | UNIQUEIDENTIFIER |
|
||||
| bytea | VARBINARY(MAX) |
|
||||
| date | DATE |
|
||||
| time | TIME |
|
||||
|
||||
## Usage
|
||||
|
||||
### Generate SQL File
|
||||
|
||||
```go
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/writers/mssql"
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
|
||||
writer := mssql.NewWriter(&writers.WriterOptions{
|
||||
OutputPath: "schema.sql",
|
||||
FlattenSchema: false,
|
||||
})
|
||||
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
||||
### Direct Database Execution
|
||||
|
||||
```go
|
||||
writer := mssql.NewWriter(&writers.WriterOptions{
|
||||
OutputPath: "",
|
||||
Metadata: map[string]interface{}{
|
||||
"connection_string": "sqlserver://sa:password@localhost/newdb",
|
||||
},
|
||||
})
|
||||
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
||||
### CLI Usage
|
||||
|
||||
Generate SQL file:
|
||||
```bash
|
||||
relspec convert --from json --from-path schema.json \
|
||||
--to mssql --to-path schema.sql
|
||||
```
|
||||
|
||||
Execute directly to database:
|
||||
```bash
|
||||
relspec convert --from json --from-path schema.json \
|
||||
--to mssql \
|
||||
--metadata '{"connection_string":"sqlserver://sa:password@localhost/mydb"}'
|
||||
```
|
||||
|
||||
## Default Values
|
||||
|
||||
The writer supports several default value patterns:
|
||||
- Functions: `GETDATE()`, `CURRENT_TIMESTAMP()`
|
||||
- Literals: strings wrapped in quotes, numbers, booleans (0/1 for BIT)
|
||||
- CAST expressions
|
||||
|
||||
## Comments/Extended Properties
|
||||
|
||||
Table and column descriptions are stored as MS_Description extended properties:
|
||||
|
||||
```sql
|
||||
EXEC sp_addextendedproperty
|
||||
@name = 'MS_Description',
|
||||
@value = 'Table description here',
|
||||
@level0type = 'SCHEMA', @level0name = 'dbo',
|
||||
@level1type = 'TABLE', @level1name = 'my_table';
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run tests with:
|
||||
```bash
|
||||
go test ./pkg/writers/mssql/...
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Views are not currently supported in the writer
|
||||
- Sequences are not supported (MSSQL uses IDENTITY instead)
|
||||
- Partitioning and advanced features are not supported
|
||||
- Generated DDL assumes no triggers or computed columns
|
||||
579
pkg/writers/mssql/writer.go
Normal file
579
pkg/writers/mssql/writer.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
_ "github.com/microsoft/go-mssqldb" // MSSQL driver
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/mssql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// Writer implements the Writer interface for MSSQL SQL output
|
||||
type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
// NewWriter creates a new MSSQL SQL writer
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
return &Writer{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
// qualTable returns a schema-qualified name using bracket notation
|
||||
func (w *Writer) qualTable(schema, name string) string {
|
||||
if w.options.FlattenSchema {
|
||||
return fmt.Sprintf("[%s]", name)
|
||||
}
|
||||
return fmt.Sprintf("[%s].[%s]", schema, name)
|
||||
}
|
||||
|
||||
// WriteDatabase writes the entire database schema as SQL
|
||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
// Check if we should execute SQL directly on a database
|
||||
if connString, ok := w.options.Metadata["connection_string"].(string); ok && connString != "" {
|
||||
return w.executeDatabaseSQL(db, connString)
|
||||
}
|
||||
|
||||
var writer io.Writer
|
||||
var file *os.File
|
||||
var err error
|
||||
|
||||
// Use existing writer if already set (for testing)
|
||||
if w.writer != nil {
|
||||
writer = w.writer
|
||||
} else if w.options.OutputPath != "" {
|
||||
// Determine output destination
|
||||
file, err = os.Create(w.options.OutputPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
writer = file
|
||||
} else {
|
||||
writer = os.Stdout
|
||||
}
|
||||
|
||||
w.writer = writer
|
||||
|
||||
// Write header comment
|
||||
fmt.Fprintf(w.writer, "-- MSSQL Database Schema\n")
|
||||
fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name)
|
||||
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n")
|
||||
|
||||
// Process each schema in the database
|
||||
for _, schema := range db.Schemas {
|
||||
if err := w.WriteSchema(schema); err != nil {
|
||||
return fmt.Errorf("failed to write schema %s: %w", schema.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteSchema writes a single schema and all its tables
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
if w.writer == nil {
|
||||
w.writer = os.Stdout
|
||||
}
|
||||
|
||||
// Phase 1: Create schema (skip dbo schema and when flattening)
|
||||
if schema.Name != "dbo" && !w.options.FlattenSchema {
|
||||
fmt.Fprintf(w.writer, "-- Schema: %s\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, "CREATE SCHEMA [%s];\n\n", schema.Name)
|
||||
}
|
||||
|
||||
// Phase 2: Create tables with columns
|
||||
fmt.Fprintf(w.writer, "-- Tables for schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
if err := w.writeCreateTable(schema, table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Primary keys
|
||||
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
if err := w.writePrimaryKey(schema, table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 4: Indexes
|
||||
fmt.Fprintf(w.writer, "-- Indexes for schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
if err := w.writeIndexes(schema, table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Unique constraints
|
||||
fmt.Fprintf(w.writer, "-- Unique constraints for schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
if err := w.writeUniqueConstraints(schema, table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 6: Check constraints
|
||||
fmt.Fprintf(w.writer, "-- Check constraints for schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
if err := w.writeCheckConstraints(schema, table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 7: Foreign keys
|
||||
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
if err := w.writeForeignKeys(schema, table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 8: Comments
|
||||
fmt.Fprintf(w.writer, "-- Comments for schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
if err := w.writeComments(schema, table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteTable writes a single table with all its elements
|
||||
func (w *Writer) WriteTable(table *models.Table) error {
|
||||
if w.writer == nil {
|
||||
w.writer = os.Stdout
|
||||
}
|
||||
|
||||
// Create a temporary schema with just this table
|
||||
schema := models.InitSchema(table.Schema)
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
|
||||
return w.WriteSchema(schema)
|
||||
}
|
||||
|
||||
// writeCreateTable generates CREATE TABLE statement
|
||||
func (w *Writer) writeCreateTable(schema *models.Schema, table *models.Table) error {
|
||||
fmt.Fprintf(w.writer, "CREATE TABLE %s (\n", w.qualTable(schema.Name, table.Name))
|
||||
|
||||
// Sort columns by sequence
|
||||
columns := getSortedColumns(table.Columns)
|
||||
columnDefs := make([]string, 0, len(columns))
|
||||
|
||||
for _, col := range columns {
|
||||
def := w.generateColumnDefinition(col)
|
||||
columnDefs = append(columnDefs, " "+def)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "%s\n", strings.Join(columnDefs, ",\n"))
|
||||
fmt.Fprintf(w.writer, ");\n\n")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateColumnDefinition generates MSSQL column definition
|
||||
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
parts := []string{fmt.Sprintf("[%s]", col.Name)}
|
||||
|
||||
// Type with length/precision
|
||||
baseType := mssql.ConvertCanonicalToMSSQL(col.Type)
|
||||
typeStr := baseType
|
||||
|
||||
// Handle specific type parameters for MSSQL
|
||||
if col.Length > 0 && col.Precision == 0 {
|
||||
// String types with length - override the default length from baseType
|
||||
if strings.HasPrefix(baseType, "NVARCHAR") || strings.HasPrefix(baseType, "VARCHAR") ||
|
||||
strings.HasPrefix(baseType, "CHAR") || strings.HasPrefix(baseType, "NCHAR") {
|
||||
if col.Length > 0 && col.Length < 8000 {
|
||||
// Extract base type without length specification
|
||||
baseName := strings.Split(baseType, "(")[0]
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseName, col.Length)
|
||||
}
|
||||
}
|
||||
} else if col.Precision > 0 {
|
||||
// Numeric types with precision/scale
|
||||
baseName := strings.Split(baseType, "(")[0]
|
||||
if col.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", baseName, col.Precision, col.Scale)
|
||||
} else {
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseName, col.Precision)
|
||||
}
|
||||
}
|
||||
|
||||
parts = append(parts, typeStr)
|
||||
|
||||
// IDENTITY for auto-increment
|
||||
if col.AutoIncrement {
|
||||
parts = append(parts, "IDENTITY(1,1)")
|
||||
}
|
||||
|
||||
// NOT NULL
|
||||
if col.NotNull {
|
||||
parts = append(parts, "NOT NULL")
|
||||
}
|
||||
|
||||
// DEFAULT
|
||||
if col.Default != nil {
|
||||
switch v := col.Default.(type) {
|
||||
case string:
|
||||
cleanDefault := stripBackticks(v)
|
||||
if strings.HasPrefix(strings.ToUpper(cleanDefault), "GETDATE") ||
|
||||
strings.HasPrefix(strings.ToUpper(cleanDefault), "CURRENT_") {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
||||
} else if cleanDefault == "true" || cleanDefault == "false" {
|
||||
if cleanDefault == "true" {
|
||||
parts = append(parts, "DEFAULT 1")
|
||||
} else {
|
||||
parts = append(parts, "DEFAULT 0")
|
||||
}
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
parts = append(parts, "DEFAULT 1")
|
||||
} else {
|
||||
parts = append(parts, "DEFAULT 0")
|
||||
}
|
||||
case int, int64:
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// writePrimaryKey generates ALTER TABLE statement for primary key
|
||||
func (w *Writer) writePrimaryKey(schema *models.Schema, table *models.Table) error {
|
||||
// Find primary key constraint
|
||||
var pkConstraint *models.Constraint
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.PrimaryKeyConstraint {
|
||||
pkConstraint = constraint
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var columnNames []string
|
||||
pkName := fmt.Sprintf("PK_%s_%s", schema.Name, table.Name)
|
||||
|
||||
if pkConstraint != nil {
|
||||
pkName = pkConstraint.Name
|
||||
columnNames = make([]string, 0, len(pkConstraint.Columns))
|
||||
for _, colName := range pkConstraint.Columns {
|
||||
columnNames = append(columnNames, fmt.Sprintf("[%s]", colName))
|
||||
}
|
||||
} else {
|
||||
// Check for columns with IsPrimaryKey = true
|
||||
for _, col := range table.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
columnNames = append(columnNames, fmt.Sprintf("[%s]", col.Name))
|
||||
}
|
||||
}
|
||||
sort.Strings(columnNames)
|
||||
}
|
||||
|
||||
if len(columnNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "ALTER TABLE %s ADD CONSTRAINT [%s] PRIMARY KEY (%s);\n\n",
|
||||
w.qualTable(schema.Name, table.Name), pkName, strings.Join(columnNames, ", "))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeIndexes generates CREATE INDEX statements
|
||||
func (w *Writer) writeIndexes(schema *models.Schema, table *models.Table) error {
|
||||
// Sort indexes by name
|
||||
indexNames := make([]string, 0, len(table.Indexes))
|
||||
for name := range table.Indexes {
|
||||
indexNames = append(indexNames, name)
|
||||
}
|
||||
sort.Strings(indexNames)
|
||||
|
||||
for _, name := range indexNames {
|
||||
index := table.Indexes[name]
|
||||
|
||||
// Skip if it's a primary key index
|
||||
if strings.HasPrefix(strings.ToLower(index.Name), "pk_") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build column list
|
||||
columnExprs := make([]string, 0, len(index.Columns))
|
||||
for _, colName := range index.Columns {
|
||||
columnExprs = append(columnExprs, fmt.Sprintf("[%s]", colName))
|
||||
}
|
||||
|
||||
if len(columnExprs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
unique := ""
|
||||
if index.Unique {
|
||||
unique = "UNIQUE "
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "CREATE %sINDEX [%s] ON %s (%s);\n\n",
|
||||
unique, index.Name, w.qualTable(schema.Name, table.Name), strings.Join(columnExprs, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeUniqueConstraints generates ALTER TABLE statements for unique constraints
|
||||
func (w *Writer) writeUniqueConstraints(schema *models.Schema, table *models.Table) error {
|
||||
// Sort constraints by name
|
||||
constraintNames := make([]string, 0)
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.UniqueConstraint {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Build column list
|
||||
columnExprs := make([]string, 0, len(constraint.Columns))
|
||||
for _, colName := range constraint.Columns {
|
||||
columnExprs = append(columnExprs, fmt.Sprintf("[%s]", colName))
|
||||
}
|
||||
|
||||
if len(columnExprs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "ALTER TABLE %s ADD CONSTRAINT [%s] UNIQUE (%s);\n\n",
|
||||
w.qualTable(schema.Name, table.Name), constraint.Name, strings.Join(columnExprs, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCheckConstraints generates ALTER TABLE statements for check constraints
|
||||
func (w *Writer) writeCheckConstraints(schema *models.Schema, table *models.Table) error {
|
||||
// Sort constraints by name
|
||||
constraintNames := make([]string, 0)
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.CheckConstraint {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
if constraint.Expression == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "ALTER TABLE %s ADD CONSTRAINT [%s] CHECK (%s);\n\n",
|
||||
w.qualTable(schema.Name, table.Name), constraint.Name, constraint.Expression)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeForeignKeys generates ALTER TABLE statements for foreign keys
|
||||
func (w *Writer) writeForeignKeys(schema *models.Schema, table *models.Table) error {
|
||||
// Process foreign key constraints
|
||||
constraintNames := make([]string, 0)
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Build column lists
|
||||
sourceColumns := make([]string, 0, len(constraint.Columns))
|
||||
for _, colName := range constraint.Columns {
|
||||
sourceColumns = append(sourceColumns, fmt.Sprintf("[%s]", colName))
|
||||
}
|
||||
|
||||
targetColumns := make([]string, 0, len(constraint.ReferencedColumns))
|
||||
for _, colName := range constraint.ReferencedColumns {
|
||||
targetColumns = append(targetColumns, fmt.Sprintf("[%s]", colName))
|
||||
}
|
||||
|
||||
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
refSchema := constraint.ReferencedSchema
|
||||
if refSchema == "" {
|
||||
refSchema = schema.Name
|
||||
}
|
||||
|
||||
onDelete := "NO ACTION"
|
||||
if constraint.OnDelete != "" {
|
||||
onDelete = strings.ToUpper(constraint.OnDelete)
|
||||
}
|
||||
|
||||
onUpdate := "NO ACTION"
|
||||
if constraint.OnUpdate != "" {
|
||||
onUpdate = strings.ToUpper(constraint.OnUpdate)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "ALTER TABLE %s ADD CONSTRAINT [%s] FOREIGN KEY (%s)\n",
|
||||
w.qualTable(schema.Name, table.Name), constraint.Name, strings.Join(sourceColumns, ", "))
|
||||
fmt.Fprintf(w.writer, " REFERENCES %s (%s)\n",
|
||||
w.qualTable(refSchema, constraint.ReferencedTable), strings.Join(targetColumns, ", "))
|
||||
fmt.Fprintf(w.writer, " ON DELETE %s ON UPDATE %s;\n\n",
|
||||
onDelete, onUpdate)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeComments generates EXEC sp_addextendedproperty statements for table and column descriptions
|
||||
func (w *Writer) writeComments(schema *models.Schema, table *models.Table) error {
|
||||
// Table comment
|
||||
if table.Description != "" {
|
||||
fmt.Fprintf(w.writer, "EXEC sp_addextendedproperty\n")
|
||||
fmt.Fprintf(w.writer, " @name = 'MS_Description',\n")
|
||||
fmt.Fprintf(w.writer, " @value = '%s',\n", escapeQuote(table.Description))
|
||||
fmt.Fprintf(w.writer, " @level0type = 'SCHEMA', @level0name = '%s',\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " @level1type = 'TABLE', @level1name = '%s';\n\n", table.Name)
|
||||
}
|
||||
|
||||
// Column comments
|
||||
for _, col := range getSortedColumns(table.Columns) {
|
||||
if col.Description != "" {
|
||||
fmt.Fprintf(w.writer, "EXEC sp_addextendedproperty\n")
|
||||
fmt.Fprintf(w.writer, " @name = 'MS_Description',\n")
|
||||
fmt.Fprintf(w.writer, " @value = '%s',\n", escapeQuote(col.Description))
|
||||
fmt.Fprintf(w.writer, " @level0type = 'SCHEMA', @level0name = '%s',\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " @level1type = 'TABLE', @level1name = '%s',\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " @level2type = 'COLUMN', @level2name = '%s';\n\n", col.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeDatabaseSQL executes SQL statements directly on an MSSQL database
|
||||
func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) error {
|
||||
// Generate SQL statements
|
||||
statements := []string{}
|
||||
statements = append(statements, "-- MSSQL Database Schema")
|
||||
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
|
||||
statements = append(statements, "-- Generated by RelSpec")
|
||||
|
||||
for _, schema := range db.Schemas {
|
||||
if err := w.generateSchemaStatements(schema, &statements); err != nil {
|
||||
return fmt.Errorf("failed to generate statements for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to database
|
||||
dbConn, err := sql.Open("mssql", connString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
if err = dbConn.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
// Execute statements
|
||||
executedCount := 0
|
||||
for i, stmt := range statements {
|
||||
stmtTrimmed := strings.TrimSpace(stmt)
|
||||
|
||||
// Skip comments and empty statements
|
||||
if strings.HasPrefix(stmtTrimmed, "--") || stmtTrimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Executing statement %d/%d...\n", i+1, len(statements))
|
||||
|
||||
_, execErr := dbConn.ExecContext(ctx, stmt)
|
||||
if execErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "⚠ Warning: Statement failed: %v\n", execErr)
|
||||
continue
|
||||
}
|
||||
|
||||
executedCount++
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "✓ Successfully executed %d statements\n", executedCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateSchemaStatements generates SQL statements for a schema
|
||||
func (w *Writer) generateSchemaStatements(schema *models.Schema, statements *[]string) error {
|
||||
// Phase 1: Create schema
|
||||
if schema.Name != "dbo" && !w.options.FlattenSchema {
|
||||
*statements = append(*statements, fmt.Sprintf("-- Schema: %s", schema.Name))
|
||||
*statements = append(*statements, fmt.Sprintf("CREATE SCHEMA [%s];", schema.Name))
|
||||
}
|
||||
|
||||
// Phase 2: Create tables
|
||||
*statements = append(*statements, fmt.Sprintf("-- Tables for schema: %s", schema.Name))
|
||||
for _, table := range schema.Tables {
|
||||
createTableSQL := fmt.Sprintf("CREATE TABLE %s (", w.qualTable(schema.Name, table.Name))
|
||||
columnDefs := make([]string, 0)
|
||||
|
||||
columns := getSortedColumns(table.Columns)
|
||||
for _, col := range columns {
|
||||
def := w.generateColumnDefinition(col)
|
||||
columnDefs = append(columnDefs, " "+def)
|
||||
}
|
||||
|
||||
createTableSQL += "\n" + strings.Join(columnDefs, ",\n") + "\n)"
|
||||
*statements = append(*statements, createTableSQL)
|
||||
}
|
||||
|
||||
// Phase 3-7: Constraints and indexes will be added by WriteSchema logic
|
||||
// For now, just create tables
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// getSortedColumns returns columns sorted by sequence
|
||||
func getSortedColumns(columns map[string]*models.Column) []*models.Column {
|
||||
names := make([]string, 0, len(columns))
|
||||
for name := range columns {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
sorted := make([]*models.Column, 0, len(columns))
|
||||
for _, name := range names {
|
||||
sorted = append(sorted, columns[name])
|
||||
}
|
||||
return sorted
|
||||
}
|
||||
|
||||
// escapeQuote escapes single quotes in strings for SQL
|
||||
func escapeQuote(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// stripBackticks removes backticks from SQL expressions
|
||||
func stripBackticks(s string) string {
|
||||
return strings.ReplaceAll(s, "`", "")
|
||||
}
|
||||
205
pkg/writers/mssql/writer_test.go
Normal file
205
pkg/writers/mssql/writer_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestGenerateColumnDefinition tests column definition generation
|
||||
func TestGenerateColumnDefinition(t *testing.T) {
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
column *models.Column
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "INT NOT NULL",
|
||||
column: &models.Column{
|
||||
Name: "id",
|
||||
Type: "int",
|
||||
NotNull: true,
|
||||
Sequence: 1,
|
||||
},
|
||||
expected: "[id] INT NOT NULL",
|
||||
},
|
||||
{
|
||||
name: "VARCHAR with length",
|
||||
column: &models.Column{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Length: 100,
|
||||
NotNull: true,
|
||||
Sequence: 2,
|
||||
},
|
||||
expected: "[name] NVARCHAR(100) NOT NULL",
|
||||
},
|
||||
{
|
||||
name: "DATETIME2 with default",
|
||||
column: &models.Column{
|
||||
Name: "created_at",
|
||||
Type: "timestamp",
|
||||
NotNull: true,
|
||||
Default: "GETDATE()",
|
||||
Sequence: 3,
|
||||
},
|
||||
expected: "[created_at] DATETIME2 NOT NULL DEFAULT GETDATE()",
|
||||
},
|
||||
{
|
||||
name: "IDENTITY column",
|
||||
column: &models.Column{
|
||||
Name: "id",
|
||||
Type: "int",
|
||||
AutoIncrement: true,
|
||||
NotNull: true,
|
||||
Sequence: 1,
|
||||
},
|
||||
expected: "[id] INT IDENTITY(1,1) NOT NULL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := writer.generateColumnDefinition(tt.column)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteCreateTable tests CREATE TABLE statement generation
|
||||
func TestWriteCreateTable(t *testing.T) {
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
|
||||
// Create a test schema with a table
|
||||
schema := models.InitSchema("dbo")
|
||||
table := models.InitTable("users", "dbo")
|
||||
|
||||
col1 := models.InitColumn("id", "users", "dbo")
|
||||
col1.Type = "int"
|
||||
col1.AutoIncrement = true
|
||||
col1.NotNull = true
|
||||
col1.Sequence = 1
|
||||
|
||||
col2 := models.InitColumn("email", "users", "dbo")
|
||||
col2.Type = "string"
|
||||
col2.Length = 255
|
||||
col2.NotNull = true
|
||||
col2.Sequence = 2
|
||||
|
||||
table.Columns["id"] = col1
|
||||
table.Columns["email"] = col2
|
||||
|
||||
// Write to buffer
|
||||
buf := &bytes.Buffer{}
|
||||
writer.writer = buf
|
||||
|
||||
err := writer.writeCreateTable(schema, table)
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "CREATE TABLE [dbo].[users]")
|
||||
assert.Contains(t, output, "[id] INT IDENTITY(1,1) NOT NULL")
|
||||
assert.Contains(t, output, "[email] NVARCHAR(255) NOT NULL")
|
||||
}
|
||||
|
||||
// TestWritePrimaryKey tests PRIMARY KEY constraint generation
|
||||
func TestWritePrimaryKey(t *testing.T) {
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
|
||||
schema := models.InitSchema("dbo")
|
||||
table := models.InitTable("users", "dbo")
|
||||
|
||||
// Add primary key constraint
|
||||
pk := models.InitConstraint("PK_users_id", models.PrimaryKeyConstraint)
|
||||
pk.Columns = []string{"id"}
|
||||
table.Constraints[pk.Name] = pk
|
||||
|
||||
// Add column
|
||||
col := models.InitColumn("id", "users", "dbo")
|
||||
col.Type = "int"
|
||||
col.Sequence = 1
|
||||
table.Columns["id"] = col
|
||||
|
||||
// Write to buffer
|
||||
buf := &bytes.Buffer{}
|
||||
writer.writer = buf
|
||||
|
||||
err := writer.writePrimaryKey(schema, table)
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "ALTER TABLE [dbo].[users]")
|
||||
assert.Contains(t, output, "PRIMARY KEY")
|
||||
assert.Contains(t, output, "[id]")
|
||||
}
|
||||
|
||||
// TestWriteForeignKey tests FOREIGN KEY constraint generation
|
||||
func TestWriteForeignKey(t *testing.T) {
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
|
||||
schema := models.InitSchema("dbo")
|
||||
table := models.InitTable("orders", "dbo")
|
||||
|
||||
// Add foreign key constraint
|
||||
fk := models.InitConstraint("FK_orders_users", models.ForeignKeyConstraint)
|
||||
fk.Columns = []string{"user_id"}
|
||||
fk.ReferencedSchema = "dbo"
|
||||
fk.ReferencedTable = "users"
|
||||
fk.ReferencedColumns = []string{"id"}
|
||||
fk.OnDelete = "CASCADE"
|
||||
fk.OnUpdate = "NO ACTION"
|
||||
table.Constraints[fk.Name] = fk
|
||||
|
||||
// Add column
|
||||
col := models.InitColumn("user_id", "orders", "dbo")
|
||||
col.Type = "int"
|
||||
col.Sequence = 1
|
||||
table.Columns["user_id"] = col
|
||||
|
||||
// Write to buffer
|
||||
buf := &bytes.Buffer{}
|
||||
writer.writer = buf
|
||||
|
||||
err := writer.writeForeignKeys(schema, table)
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "ALTER TABLE [dbo].[orders]")
|
||||
assert.Contains(t, output, "FK_orders_users")
|
||||
assert.Contains(t, output, "FOREIGN KEY")
|
||||
assert.Contains(t, output, "CASCADE")
|
||||
assert.Contains(t, output, "NO ACTION")
|
||||
}
|
||||
|
||||
// TestWriteComments tests extended property generation for comments
|
||||
func TestWriteComments(t *testing.T) {
|
||||
writer := NewWriter(&writers.WriterOptions{})
|
||||
|
||||
schema := models.InitSchema("dbo")
|
||||
table := models.InitTable("users", "dbo")
|
||||
table.Description = "User accounts table"
|
||||
|
||||
col := models.InitColumn("id", "users", "dbo")
|
||||
col.Type = "int"
|
||||
col.Description = "Primary key"
|
||||
col.Sequence = 1
|
||||
table.Columns["id"] = col
|
||||
|
||||
// Write to buffer
|
||||
buf := &bytes.Buffer{}
|
||||
writer.writer = buf
|
||||
|
||||
err := writer.writeComments(schema, table)
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "sp_addextendedproperty")
|
||||
assert.Contains(t, output, "MS_Description")
|
||||
assert.Contains(t, output, "User accounts table")
|
||||
assert.Contains(t, output, "Primary key")
|
||||
}
|
||||
@@ -31,7 +31,7 @@ type MigrationWriter struct {
|
||||
|
||||
// NewMigrationWriter creates a new templated migration writer
|
||||
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
|
||||
executor, err := NewTemplateExecutor()
|
||||
executor, err := NewTemplateExecutor(options.FlattenSchema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template executor: %w", err)
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ func TestWriteMigration_WithAudit(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTemplateExecutor_CreateTable(t *testing.T) {
|
||||
executor, err := NewTemplateExecutor()
|
||||
executor, err := NewTemplateExecutor(false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create executor: %v", err)
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func TestTemplateExecutor_CreateTable(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
||||
executor, err := NewTemplateExecutor()
|
||||
executor, err := NewTemplateExecutor(false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create executor: %v", err)
|
||||
}
|
||||
|
||||
@@ -314,7 +314,7 @@ func TestFormatType(t *testing.T) {
|
||||
|
||||
// Test that template functions work in actual templates
|
||||
func TestTemplateFunctionsInTemplate(t *testing.T) {
|
||||
executor, err := NewTemplateExecutor()
|
||||
executor, err := NewTemplateExecutor(false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create executor: %v", err)
|
||||
}
|
||||
|
||||
@@ -18,14 +18,39 @@ type TemplateExecutor struct {
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
// NewTemplateExecutor creates a new template executor
|
||||
func NewTemplateExecutor() (*TemplateExecutor, error) {
|
||||
// NewTemplateExecutor creates a new template executor.
|
||||
// flattenSchema controls whether schema.table identifiers use dot or underscore separation.
|
||||
func NewTemplateExecutor(flattenSchema bool) (*TemplateExecutor, error) {
|
||||
// Create template with custom functions
|
||||
funcMap := make(template.FuncMap)
|
||||
for k, v := range TemplateFunctions() {
|
||||
funcMap[k] = v
|
||||
}
|
||||
|
||||
// qual_table returns a quoted, schema-qualified identifier.
|
||||
// With flatten=false: "schema"."table" (or unquoted equivalents).
|
||||
// With flatten=true: "schema_table".
|
||||
funcMap["qual_table"] = func(schema, name string) string {
|
||||
if schema == "" {
|
||||
return quoteIdent(name)
|
||||
}
|
||||
if flattenSchema {
|
||||
return quoteIdent(schema + "_" + name)
|
||||
}
|
||||
return quoteIdent(schema) + "." + quoteIdent(name)
|
||||
}
|
||||
|
||||
// qual_table_raw is the same as qual_table but without identifier quoting.
|
||||
funcMap["qual_table_raw"] = func(schema, name string) string {
|
||||
if schema == "" {
|
||||
return name
|
||||
}
|
||||
if flattenSchema {
|
||||
return schema + "_" + name
|
||||
}
|
||||
return schema + "." + name
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").Funcs(funcMap).ParseFS(templateFS, "templates/*.tmpl")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse templates: %w", err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
|
||||
{{- if .Default}} DEFAULT {{.Default}}{{end}}
|
||||
{{- if .NotNull}} NOT NULL{{end}};
|
||||
@@ -6,7 +6,7 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND column_name = '{{.ColumnName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD COLUMN {{.ColumnDefinition}};
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD COLUMN {{.ColumnDefinition}};
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{{- if .SetDefault -}}
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||
{{- else -}}
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
|
||||
{{- end -}}
|
||||
@@ -1,2 +1,2 @@
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
||||
@@ -1,4 +1,4 @@
|
||||
CREATE OR REPLACE FUNCTION {{.SchemaName}}.{{.FunctionName}}()
|
||||
CREATE OR REPLACE FUNCTION {{qual_table_raw .SchemaName .FunctionName}}()
|
||||
RETURNS trigger AS
|
||||
$body$
|
||||
DECLARE
|
||||
@@ -81,4 +81,4 @@ LANGUAGE plpgsql
|
||||
VOLATILE
|
||||
SECURITY DEFINER;
|
||||
|
||||
COMMENT ON FUNCTION {{.SchemaName}}.{{.FunctionName}}() IS 'Audit trigger function for table {{.SchemaName}}.{{.TableName}}';
|
||||
COMMENT ON FUNCTION {{qual_table_raw .SchemaName .FunctionName}}() IS 'Audit trigger function for table {{qual_table_raw .SchemaName .TableName}}';
|
||||
@@ -4,13 +4,13 @@ BEGIN
|
||||
SELECT 1
|
||||
FROM pg_trigger
|
||||
WHERE tgname = '{{.TriggerName}}'
|
||||
AND tgrelid = '{{.SchemaName}}.{{.TableName}}'::regclass
|
||||
AND tgrelid = '{{qual_table_raw .SchemaName .TableName}}'::regclass
|
||||
) THEN
|
||||
CREATE TRIGGER {{.TriggerName}}
|
||||
AFTER {{.Events}}
|
||||
ON {{.SchemaName}}.{{.TableName}}
|
||||
ON {{qual_table_raw .SchemaName .TableName}}
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION {{.SchemaName}}.{{.FunctionName}}();
|
||||
EXECUTE FUNCTION {{qual_table_raw .SchemaName .FunctionName}}();
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,6 +1,6 @@
|
||||
{{/* Base constraint template */}}
|
||||
{{- define "constraint_base" -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||
ADD CONSTRAINT {{.ConstraintName}}
|
||||
{{block "constraint_definition" .}}{{end}};
|
||||
{{- end -}}
|
||||
@@ -15,7 +15,7 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||
DROP CONSTRAINT {{.ConstraintName}};
|
||||
END IF;
|
||||
END;
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
{{/* Base ALTER TABLE structure */}}
|
||||
{{- define "alter_table_base" -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER TABLE {{qual_table_raw .SchemaName .TableName}}
|
||||
{{block "alter_operation" .}}{{end}};
|
||||
{{- end -}}
|
||||
|
||||
@@ -30,5 +30,5 @@ $$;
|
||||
|
||||
{{/* Common drop pattern */}}
|
||||
{{- define "drop_if_exists" -}}
|
||||
{{block "drop_type" .}}{{end}} IF EXISTS {{.SchemaName}}.{{.ObjectName}};
|
||||
{{block "drop_type" .}}{{end}} IF EXISTS {{qual_table_raw .SchemaName .ObjectName}};
|
||||
{{- end -}}
|
||||
@@ -1 +1 @@
|
||||
COMMENT ON COLUMN {{quote_ident .SchemaName}}.{{quote_ident .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';
|
||||
COMMENT ON COLUMN {{qual_table .SchemaName .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';
|
||||
@@ -1 +1 @@
|
||||
COMMENT ON TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} IS '{{.Comment}}';
|
||||
COMMENT ON TABLE {{qual_table .SchemaName .TableName}} IS '{{.Comment}}';
|
||||
@@ -6,7 +6,7 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,10 +1,10 @@
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||
FOREIGN KEY ({{.SourceColumns}})
|
||||
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
|
||||
REFERENCES {{qual_table .TargetSchema .TargetTable}} ({{.TargetColumns}})
|
||||
ON DELETE {{.OnDelete}}
|
||||
ON UPDATE {{.OnUpdate}}
|
||||
DEFERRABLE;
|
||||
@@ -6,10 +6,10 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||
FOREIGN KEY ({{.SourceColumns}})
|
||||
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
|
||||
REFERENCES {{qual_table .TargetSchema .TargetTable}} ({{.TargetColumns}})
|
||||
ON DELETE {{.OnDelete}}
|
||||
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
|
||||
DEFERRABLE{{end}};
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
|
||||
ON {{quote_ident .SchemaName}}.{{quote_ident .TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||
ON {{qual_table .SchemaName .TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||
@@ -6,7 +6,7 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
|
||||
@@ -11,7 +11,7 @@ BEGIN
|
||||
AND constraint_name IN ({{.AutoGenNames}});
|
||||
|
||||
IF auto_pk_name IS NOT NULL THEN
|
||||
EXECUTE 'ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||
EXECUTE 'ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||
END IF;
|
||||
|
||||
-- Add named primary key if it doesn't exist
|
||||
@@ -21,7 +21,7 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
CREATE SEQUENCE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}
|
||||
CREATE SEQUENCE IF NOT EXISTS {{qual_table .SchemaName .SequenceName}}
|
||||
INCREMENT {{.Increment}}
|
||||
MINVALUE {{.MinValue}}
|
||||
MAXVALUE {{.MaxValue}}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
CREATE TABLE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .TableName}} (
|
||||
CREATE TABLE IF NOT EXISTS {{qual_table .SchemaName .TableName}} (
|
||||
{{- range $i, $col := .Columns}}
|
||||
{{- if $i}},{{end}}
|
||||
{{quote_ident $col.Name}} {{$col.Type}}
|
||||
|
||||
@@ -6,7 +6,7 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1 +1 @@
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
ALTER TABLE {{qual_table .SchemaName .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
@@ -1 +1 @@
|
||||
DROP INDEX IF EXISTS {{quote_ident .SchemaName}}.{{quote_ident .IndexName}} CASCADE;
|
||||
DROP INDEX IF EXISTS {{qual_table .SchemaName .IndexName}} CASCADE;
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
{{/* Qualified table name */}}
|
||||
{{- define "qualified_table" -}}
|
||||
{{.SchemaName}}.{{.TableName}}
|
||||
{{qual_table_raw .SchemaName .TableName}}
|
||||
{{- end -}}
|
||||
|
||||
{{/* Index method clause */}}
|
||||
|
||||
@@ -10,10 +10,10 @@ BEGIN
|
||||
AND c.relkind = 'S'
|
||||
) THEN
|
||||
SELECT COALESCE(MAX({{quote_ident .ColumnName}}), 0) + 1
|
||||
FROM {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
FROM {{qual_table .SchemaName .TableName}}
|
||||
INTO m_cnt;
|
||||
|
||||
PERFORM setval('{{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}'::regclass, m_cnt);
|
||||
PERFORM setval('{{qual_table_raw .SchemaName .SequenceName}}'::regclass, m_cnt);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -58,13 +58,18 @@ type ExecutionError struct {
|
||||
|
||||
// NewWriter creates a new PostgreSQL SQL writer
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
executor, _ := NewTemplateExecutor()
|
||||
executor, _ := NewTemplateExecutor(options.FlattenSchema)
|
||||
return &Writer{
|
||||
options: options,
|
||||
executor: executor,
|
||||
}
|
||||
}
|
||||
|
||||
// qualTable returns a schema-qualified name using the writer's FlattenSchema setting.
|
||||
func (w *Writer) qualTable(schema, name string) string {
|
||||
return writers.QualifiedTableName(schema, name, w.options.FlattenSchema)
|
||||
}
|
||||
|
||||
// WriteDatabase writes the entire database schema as SQL
|
||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
// Check if we should execute SQL directly on a database
|
||||
@@ -134,8 +139,8 @@ func (w *Writer) GenerateDatabaseStatements(db *models.Database) ([]string, erro
|
||||
func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, error) {
|
||||
statements := []string{}
|
||||
|
||||
// Phase 1: Create schema
|
||||
if schema.Name != "public" {
|
||||
// Phase 1: Create schema (skip entirely when flattening)
|
||||
if schema.Name != "public" && !w.options.FlattenSchema {
|
||||
statements = append(statements, fmt.Sprintf("-- Schema: %s", schema.Name))
|
||||
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
||||
}
|
||||
@@ -157,8 +162,8 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
continue
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s.%s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1",
|
||||
schema.SQLName(), seqName)
|
||||
stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1",
|
||||
w.qualTable(schema.SQLName(), seqName))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
@@ -275,8 +280,8 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
||||
uniqueStr, quoteIdentifier(index.Name), schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s USING %s (%s)%s",
|
||||
uniqueStr, quoteIdentifier(index.Name), w.qualTable(schema.SQLName(), table.SQLName()), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -374,15 +379,15 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
// Phase 7: Comments
|
||||
for _, table := range schema.Tables {
|
||||
if table.Comment != "" {
|
||||
stmt := fmt.Sprintf("COMMENT ON TABLE %s.%s IS '%s'",
|
||||
schema.SQLName(), table.SQLName(), escapeQuote(table.Comment))
|
||||
stmt := fmt.Sprintf("COMMENT ON TABLE %s IS '%s'",
|
||||
w.qualTable(schema.SQLName(), table.SQLName()), escapeQuote(table.Comment))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
for _, column := range table.Columns {
|
||||
if column.Comment != "" {
|
||||
stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'",
|
||||
schema.SQLName(), table.SQLName(), column.SQLName(), escapeQuote(column.Comment))
|
||||
stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'",
|
||||
w.qualTable(schema.SQLName(), table.SQLName()), column.SQLName(), escapeQuote(column.Comment))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -474,8 +479,8 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
|
||||
columnDefs = append(columnDefs, " "+def)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (\n%s\n)",
|
||||
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
||||
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n%s\n)",
|
||||
w.qualTable(schema.SQLName(), table.SQLName()), strings.Join(columnDefs, ",\n"))
|
||||
statements = append(statements, stmt)
|
||||
|
||||
return statements, nil
|
||||
@@ -655,8 +660,7 @@ func (w *Writer) WriteAddColumnStatements(db *models.Database) error {
|
||||
|
||||
// writeCreateSchema generates CREATE SCHEMA statement
|
||||
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
||||
if schema.Name == "public" {
|
||||
// public schema exists by default
|
||||
if schema.Name == "public" || w.options.FlattenSchema {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -708,8 +712,8 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Tables for schema: %s\n", schema.Name)
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
fmt.Fprintf(w.writer, "CREATE TABLE IF NOT EXISTS %s.%s (\n",
|
||||
schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, "CREATE TABLE IF NOT EXISTS %s (\n",
|
||||
w.qualTable(schema.SQLName(), table.SQLName()))
|
||||
|
||||
// Write columns
|
||||
columns := getSortedColumns(table.Columns)
|
||||
@@ -893,8 +897,8 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
||||
|
||||
fmt.Fprintf(w.writer, "CREATE %sINDEX IF NOT EXISTS %s\n",
|
||||
unique, indexName)
|
||||
fmt.Fprintf(w.writer, " ON %s.%s USING %s (%s)%s;\n\n",
|
||||
schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
fmt.Fprintf(w.writer, " ON %s USING %s (%s)%s;\n\n",
|
||||
w.qualTable(schema.SQLName(), table.SQLName()), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1203,16 +1207,16 @@ func (w *Writer) writeComments(schema *models.Schema) error {
|
||||
for _, table := range schema.Tables {
|
||||
// Table comment
|
||||
if table.Description != "" {
|
||||
fmt.Fprintf(w.writer, "COMMENT ON TABLE %s.%s IS '%s';\n",
|
||||
schema.SQLName(), table.SQLName(),
|
||||
fmt.Fprintf(w.writer, "COMMENT ON TABLE %s IS '%s';\n",
|
||||
w.qualTable(schema.SQLName(), table.SQLName()),
|
||||
escapeQuote(table.Description))
|
||||
}
|
||||
|
||||
// Column comments
|
||||
for _, col := range getSortedColumns(table.Columns) {
|
||||
if col.Description != "" {
|
||||
fmt.Fprintf(w.writer, "COMMENT ON COLUMN %s.%s.%s IS '%s';\n",
|
||||
schema.SQLName(), table.SQLName(), col.SQLName(),
|
||||
fmt.Fprintf(w.writer, "COMMENT ON COLUMN %s.%s IS '%s';\n",
|
||||
w.qualTable(schema.SQLName(), table.SQLName()), col.SQLName(),
|
||||
escapeQuote(col.Description))
|
||||
}
|
||||
}
|
||||
|
||||
215
pkg/writers/sqlite/README.md
Normal file
215
pkg/writers/sqlite/README.md
Normal file
@@ -0,0 +1,215 @@
|
||||
# SQLite Writer
|
||||
|
||||
SQLite DDL (Data Definition Language) writer for RelSpec. Converts database schemas to SQLite-compatible SQL statements.
|
||||
|
||||
## Features
|
||||
|
||||
- **Automatic Schema Flattening** - SQLite doesn't support PostgreSQL-style schemas, so table names are automatically flattened (e.g., `public.users` → `public_users`)
|
||||
- **Type Mapping** - Converts PostgreSQL data types to SQLite type affinities (TEXT, INTEGER, REAL, NUMERIC, BLOB)
|
||||
- **Auto-Increment Detection** - Automatically converts SERIAL types and auto-increment columns to `INTEGER PRIMARY KEY AUTOINCREMENT`
|
||||
- **Function Translation** - Converts PostgreSQL functions to SQLite equivalents (e.g., `now()` → `CURRENT_TIMESTAMP`)
|
||||
- **Boolean Handling** - Maps boolean values to INTEGER (true=1, false=0)
|
||||
- **Constraint Generation** - Creates indexes, unique constraints, and documents foreign keys
|
||||
- **Identifier Quoting** - Properly quotes identifiers using double quotes
|
||||
|
||||
## Usage
|
||||
|
||||
### Convert PostgreSQL to SQLite
|
||||
|
||||
```bash
|
||||
relspec convert --from pgsql --from-conn "postgres://user:pass@localhost/mydb" \
|
||||
--to sqlite --to-path schema.sql
|
||||
```
|
||||
|
||||
### Convert DBML to SQLite
|
||||
|
||||
```bash
|
||||
relspec convert --from dbml --from-path schema.dbml \
|
||||
--to sqlite --to-path schema.sql
|
||||
```
|
||||
|
||||
### Multi-Schema Databases
|
||||
|
||||
SQLite doesn't support schemas, so multi-schema databases are automatically flattened:
|
||||
|
||||
```bash
|
||||
# Input has auth.users and public.posts
|
||||
# Output will have auth_users and public_posts
|
||||
relspec convert --from json --from-path multi_schema.json \
|
||||
--to sqlite --to-path flattened.sql
|
||||
```
|
||||
|
||||
## Type Mapping
|
||||
|
||||
| PostgreSQL Type | SQLite Affinity | Examples |
|
||||
|----------------|-----------------|----------|
|
||||
| TEXT | TEXT | varchar, text, char, citext, uuid, timestamp, json |
|
||||
| INTEGER | INTEGER | int, integer, smallint, bigint, serial, boolean |
|
||||
| REAL | REAL | real, float, double precision |
|
||||
| NUMERIC | NUMERIC | numeric, decimal |
|
||||
| BLOB | BLOB | bytea, blob |
|
||||
|
||||
## Auto-Increment Handling
|
||||
|
||||
Columns are converted to `INTEGER PRIMARY KEY AUTOINCREMENT` when they meet these criteria:
|
||||
- Marked as primary key
|
||||
- Integer type
|
||||
- Have `AutoIncrement` flag set, OR
|
||||
- Type contains "serial", OR
|
||||
- Default value contains "nextval"
|
||||
|
||||
**Example:**
|
||||
|
||||
```sql
|
||||
-- Input (PostgreSQL)
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(100)
|
||||
);
|
||||
|
||||
-- Output (SQLite)
|
||||
CREATE TABLE "users" (
|
||||
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
"name" TEXT
|
||||
);
|
||||
```
|
||||
|
||||
## Default Value Translation
|
||||
|
||||
| PostgreSQL | SQLite | Notes |
|
||||
|-----------|--------|-------|
|
||||
| `now()`, `CURRENT_TIMESTAMP` | `CURRENT_TIMESTAMP` | Timestamp functions |
|
||||
| `CURRENT_DATE` | `CURRENT_DATE` | Date function |
|
||||
| `CURRENT_TIME` | `CURRENT_TIME` | Time function |
|
||||
| `true`, `false` | `1`, `0` | Boolean values |
|
||||
| `gen_random_uuid()` | *(removed)* | SQLite has no built-in UUID |
|
||||
| `nextval(...)` | *(removed)* | Handled by AUTOINCREMENT |
|
||||
|
||||
## Foreign Keys
|
||||
|
||||
Foreign keys are generated as commented-out ALTER TABLE statements for reference:
|
||||
|
||||
```sql
|
||||
-- Foreign key: fk_posts_user_id
|
||||
-- ALTER TABLE "posts" ADD CONSTRAINT "posts_fk_posts_user_id"
|
||||
-- FOREIGN KEY ("user_id")
|
||||
-- REFERENCES "users" ("id");
|
||||
-- Note: Foreign keys should be defined in CREATE TABLE for better SQLite compatibility
|
||||
```
|
||||
|
||||
For production use, define foreign keys directly in the CREATE TABLE statement or execute the ALTER TABLE commands after creating all tables.
|
||||
|
||||
## Constraints
|
||||
|
||||
- **Primary Keys**: Inline for auto-increment columns, separate constraint for composite keys
|
||||
- **Unique Constraints**: Converted to `CREATE UNIQUE INDEX` statements
|
||||
- **Check Constraints**: Generated as comments (should be added to CREATE TABLE manually)
|
||||
- **Indexes**: Generated without PostgreSQL-specific features (no GIN, GiST, operator classes)
|
||||
|
||||
## Output Structure
|
||||
|
||||
Generated SQL follows this order:
|
||||
|
||||
1. Header comments
|
||||
2. `PRAGMA foreign_keys = ON;`
|
||||
3. CREATE TABLE statements (sorted by schema, then table)
|
||||
4. CREATE INDEX statements
|
||||
5. CREATE UNIQUE INDEX statements (for unique constraints)
|
||||
6. Check constraint comments
|
||||
7. Foreign key comments
|
||||
|
||||
## Example
|
||||
|
||||
**Input (multi-schema PostgreSQL):**
|
||||
|
||||
```sql
|
||||
CREATE SCHEMA auth;
|
||||
CREATE TABLE auth.users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(50) UNIQUE NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE SCHEMA public;
|
||||
CREATE TABLE public.posts (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INTEGER REFERENCES auth.users(id),
|
||||
title VARCHAR(200) NOT NULL,
|
||||
published BOOLEAN DEFAULT false
|
||||
);
|
||||
```
|
||||
|
||||
**Output (SQLite with flattened schemas):**
|
||||
|
||||
```sql
|
||||
-- SQLite Database Schema
|
||||
-- Database: mydb
|
||||
-- Generated by RelSpec
|
||||
-- Note: Schema names have been flattened (e.g., public.users -> public_users)
|
||||
|
||||
-- Enable foreign key constraints
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- Schema: auth (flattened into table names)
|
||||
|
||||
CREATE TABLE "auth_users" (
|
||||
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
"username" TEXT NOT NULL,
|
||||
"created_at" TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX "auth_users_users_username_key" ON "auth_users" ("username");
|
||||
|
||||
-- Schema: public (flattened into table names)
|
||||
|
||||
CREATE TABLE "public_posts" (
|
||||
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
"user_id" INTEGER NOT NULL,
|
||||
"title" TEXT NOT NULL,
|
||||
"published" INTEGER DEFAULT 0
|
||||
);
|
||||
|
||||
-- Foreign key: posts_user_id_fkey
|
||||
-- ALTER TABLE "public_posts" ADD CONSTRAINT "public_posts_posts_user_id_fkey"
|
||||
-- FOREIGN KEY ("user_id")
|
||||
-- REFERENCES "auth_users" ("id");
|
||||
-- Note: Foreign keys should be defined in CREATE TABLE for better SQLite compatibility
|
||||
```
|
||||
|
||||
## Programmatic Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers/sqlite"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create writer (automatically enables schema flattening)
|
||||
writer := sqlite.NewWriter(&writers.WriterOptions{
|
||||
OutputPath: "schema.sql",
|
||||
})
|
||||
|
||||
// Write database schema
|
||||
db := &models.Database{
|
||||
Name: "mydb",
|
||||
Schemas: []*models.Schema{
|
||||
// ... your schema data
|
||||
},
|
||||
}
|
||||
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Schema flattening is **always enabled** for SQLite output (cannot be disabled)
|
||||
- Constraint and index names are prefixed with the flattened table name to avoid collisions
|
||||
- Generated SQL is compatible with SQLite 3.x
|
||||
- Foreign key constraints require `PRAGMA foreign_keys = ON;` to be enforced
|
||||
- For complex schemas, review and test the generated SQL before use in production
|
||||
89
pkg/writers/sqlite/datatypes.go
Normal file
89
pkg/writers/sqlite/datatypes.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SQLite type affinities
|
||||
const (
|
||||
TypeText = "TEXT"
|
||||
TypeInteger = "INTEGER"
|
||||
TypeReal = "REAL"
|
||||
TypeNumeric = "NUMERIC"
|
||||
TypeBlob = "BLOB"
|
||||
)
|
||||
|
||||
// MapPostgreSQLType maps PostgreSQL data types to SQLite type affinities
|
||||
func MapPostgreSQLType(pgType string) string {
|
||||
// Normalize the type
|
||||
normalized := strings.ToLower(strings.TrimSpace(pgType))
|
||||
|
||||
// Remove array notation if present
|
||||
normalized = strings.TrimSuffix(normalized, "[]")
|
||||
|
||||
// Remove precision/scale if present
|
||||
if idx := strings.Index(normalized, "("); idx != -1 {
|
||||
normalized = normalized[:idx]
|
||||
}
|
||||
|
||||
// Map to SQLite type affinity
|
||||
switch normalized {
|
||||
// TEXT affinity
|
||||
case "varchar", "character varying", "text", "char", "character",
|
||||
"citext", "uuid", "timestamp", "timestamptz", "timestamp with time zone",
|
||||
"timestamp without time zone", "date", "time", "timetz", "time with time zone",
|
||||
"time without time zone", "json", "jsonb", "xml", "inet", "cidr", "macaddr":
|
||||
return TypeText
|
||||
|
||||
// INTEGER affinity
|
||||
case "int", "int2", "int4", "int8", "integer", "smallint", "bigint",
|
||||
"serial", "smallserial", "bigserial", "boolean", "bool":
|
||||
return TypeInteger
|
||||
|
||||
// REAL affinity
|
||||
case "real", "float", "float4", "float8", "double precision":
|
||||
return TypeReal
|
||||
|
||||
// NUMERIC affinity
|
||||
case "numeric", "decimal", "money":
|
||||
return TypeNumeric
|
||||
|
||||
// BLOB affinity
|
||||
case "bytea", "blob":
|
||||
return TypeBlob
|
||||
|
||||
default:
|
||||
// Default to TEXT for unknown types
|
||||
return TypeText
|
||||
}
|
||||
}
|
||||
|
||||
// IsIntegerType checks if a column type should be treated as integer
|
||||
func IsIntegerType(colType string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(colType))
|
||||
normalized = strings.TrimSuffix(normalized, "[]")
|
||||
if idx := strings.Index(normalized, "("); idx != -1 {
|
||||
normalized = normalized[:idx]
|
||||
}
|
||||
|
||||
switch normalized {
|
||||
case "int", "int2", "int4", "int8", "integer", "smallint", "bigint",
|
||||
"serial", "smallserial", "bigserial":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// MapBooleanValue converts PostgreSQL boolean literals to SQLite (0/1)
|
||||
func MapBooleanValue(value string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||
switch normalized {
|
||||
case "true", "t", "yes", "y", "1":
|
||||
return "1"
|
||||
case "false", "f", "no", "n", "0":
|
||||
return "0"
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
146
pkg/writers/sqlite/template_functions.go
Normal file
146
pkg/writers/sqlite/template_functions.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// GetTemplateFuncs returns template functions for SQLite SQL generation
|
||||
func GetTemplateFuncs(opts *writers.WriterOptions) template.FuncMap {
|
||||
return template.FuncMap{
|
||||
"quote_ident": QuoteIdentifier,
|
||||
"map_type": MapPostgreSQLType,
|
||||
"is_autoincrement": IsAutoIncrementCandidate,
|
||||
"qualified_table_name": func(schema, table string) string {
|
||||
return writers.QualifiedTableName(schema, table, opts.FlattenSchema)
|
||||
},
|
||||
"format_default": FormatDefault,
|
||||
"format_constraint_name": func(schema, table, constraint string) string {
|
||||
return FormatConstraintName(schema, table, constraint, opts)
|
||||
},
|
||||
"join": strings.Join,
|
||||
"lower": strings.ToLower,
|
||||
"upper": strings.ToUpper,
|
||||
}
|
||||
}
|
||||
|
||||
// QuoteIdentifier quotes an identifier for SQLite (double quotes)
|
||||
func QuoteIdentifier(name string) string {
|
||||
// SQLite uses double quotes for identifiers
|
||||
// Escape any existing double quotes by doubling them
|
||||
escaped := strings.ReplaceAll(name, `"`, `""`)
|
||||
return fmt.Sprintf(`"%s"`, escaped)
|
||||
}
|
||||
|
||||
// IsAutoIncrementCandidate checks if a column should use AUTOINCREMENT
|
||||
func IsAutoIncrementCandidate(col *models.Column) bool {
|
||||
// Must be a primary key
|
||||
if !col.IsPrimaryKey {
|
||||
return false
|
||||
}
|
||||
|
||||
// Must be an integer type
|
||||
if !IsIntegerType(col.Type) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check AutoIncrement field
|
||||
if col.AutoIncrement {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if default suggests auto-increment
|
||||
if col.Default != nil {
|
||||
defaultStr, ok := col.Default.(string)
|
||||
if ok {
|
||||
defaultLower := strings.ToLower(defaultStr)
|
||||
if strings.Contains(defaultLower, "nextval") ||
|
||||
strings.Contains(defaultLower, "autoincrement") ||
|
||||
strings.Contains(defaultLower, "auto_increment") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Serial types are auto-increment
|
||||
typeLower := strings.ToLower(col.Type)
|
||||
return strings.Contains(typeLower, "serial")
|
||||
}
|
||||
|
||||
// FormatDefault formats a default value for SQLite
|
||||
func FormatDefault(col *models.Column) string {
|
||||
if col.Default == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip auto-increment defaults (handled by AUTOINCREMENT keyword)
|
||||
if IsAutoIncrementCandidate(col) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Convert to string
|
||||
defaultStr, ok := col.Default.(string)
|
||||
if !ok {
|
||||
// If not a string, convert to string representation
|
||||
defaultStr = fmt.Sprintf("%v", col.Default)
|
||||
}
|
||||
|
||||
if defaultStr == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Convert PostgreSQL-specific functions to SQLite equivalents
|
||||
defaultLower := strings.ToLower(defaultStr)
|
||||
|
||||
// Current timestamp functions
|
||||
if strings.Contains(defaultLower, "current_timestamp") ||
|
||||
strings.Contains(defaultLower, "now()") {
|
||||
return "CURRENT_TIMESTAMP"
|
||||
}
|
||||
|
||||
// Current date
|
||||
if strings.Contains(defaultLower, "current_date") {
|
||||
return "CURRENT_DATE"
|
||||
}
|
||||
|
||||
// Current time
|
||||
if strings.Contains(defaultLower, "current_time") {
|
||||
return "CURRENT_TIME"
|
||||
}
|
||||
|
||||
// Boolean values
|
||||
sqliteType := MapPostgreSQLType(col.Type)
|
||||
if sqliteType == TypeInteger {
|
||||
typeLower := strings.ToLower(col.Type)
|
||||
if strings.Contains(typeLower, "bool") {
|
||||
return MapBooleanValue(defaultStr)
|
||||
}
|
||||
}
|
||||
|
||||
// UUID generation - SQLite doesn't have built-in UUID, comment it out
|
||||
if strings.Contains(defaultLower, "uuid") || strings.Contains(defaultLower, "gen_random_uuid") {
|
||||
return "" // Remove UUID defaults, users must handle this
|
||||
}
|
||||
|
||||
// Remove PostgreSQL-specific casting
|
||||
defaultStr = strings.ReplaceAll(defaultStr, "::text", "")
|
||||
defaultStr = strings.ReplaceAll(defaultStr, "::integer", "")
|
||||
defaultStr = strings.ReplaceAll(defaultStr, "::bigint", "")
|
||||
defaultStr = strings.ReplaceAll(defaultStr, "::boolean", "")
|
||||
|
||||
return defaultStr
|
||||
}
|
||||
|
||||
// FormatConstraintName formats a constraint name with table prefix if flattening
|
||||
func FormatConstraintName(schema, table, constraint string, opts *writers.WriterOptions) string {
|
||||
if opts.FlattenSchema && schema != "" {
|
||||
// Prefix constraint with flattened table name
|
||||
flatTable := writers.QualifiedTableName(schema, table, opts.FlattenSchema)
|
||||
return fmt.Sprintf("%s_%s", flatTable, constraint)
|
||||
}
|
||||
return constraint
|
||||
}
|
||||
174
pkg/writers/sqlite/templates.go
Normal file
174
pkg/writers/sqlite/templates.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"text/template"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
//go:embed templates/*.tmpl
|
||||
var templateFS embed.FS
|
||||
|
||||
// TemplateExecutor manages and executes SQLite SQL templates
|
||||
type TemplateExecutor struct {
|
||||
templates *template.Template
|
||||
options *writers.WriterOptions
|
||||
}
|
||||
|
||||
// NewTemplateExecutor creates a new template executor for SQLite
|
||||
func NewTemplateExecutor(opts *writers.WriterOptions) (*TemplateExecutor, error) {
|
||||
// Create template with SQLite-specific functions
|
||||
funcMap := GetTemplateFuncs(opts)
|
||||
|
||||
tmpl, err := template.New("").Funcs(funcMap).ParseFS(templateFS, "templates/*.tmpl")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse templates: %w", err)
|
||||
}
|
||||
|
||||
return &TemplateExecutor{
|
||||
templates: tmpl,
|
||||
options: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Template data structures
|
||||
|
||||
// TableTemplateData contains data for table template
|
||||
type TableTemplateData struct {
|
||||
Schema string
|
||||
Name string
|
||||
Columns []*models.Column
|
||||
PrimaryKey *models.Constraint
|
||||
}
|
||||
|
||||
// IndexTemplateData contains data for index template
|
||||
type IndexTemplateData struct {
|
||||
Schema string
|
||||
Table string
|
||||
Name string
|
||||
Columns []string
|
||||
}
|
||||
|
||||
// ConstraintTemplateData contains data for constraint templates
|
||||
type ConstraintTemplateData struct {
|
||||
Schema string
|
||||
Table string
|
||||
Name string
|
||||
Columns []string
|
||||
Expression string
|
||||
ForeignSchema string
|
||||
ForeignTable string
|
||||
ForeignColumns []string
|
||||
OnDelete string
|
||||
OnUpdate string
|
||||
}
|
||||
|
||||
// Execute methods
|
||||
|
||||
// ExecutePragmaForeignKeys executes the pragma foreign keys template
|
||||
func (te *TemplateExecutor) ExecutePragmaForeignKeys() (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "pragma_foreign_keys.tmpl", nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute pragma_foreign_keys template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateTable executes the create table template
|
||||
func (te *TemplateExecutor) ExecuteCreateTable(data TableTemplateData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_table.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_table template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateIndex executes the create index template
|
||||
func (te *TemplateExecutor) ExecuteCreateIndex(data IndexTemplateData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_index.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_index template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateUniqueConstraint executes the create unique constraint template
|
||||
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data ConstraintTemplateData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_unique_constraint.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_unique_constraint template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateCheckConstraint executes the create check constraint template
|
||||
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data ConstraintTemplateData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_check_constraint.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_check_constraint template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateForeignKey executes the create foreign key template
|
||||
func (te *TemplateExecutor) ExecuteCreateForeignKey(data ConstraintTemplateData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_foreign_key.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_foreign_key template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// Helper functions to build template data from models
|
||||
|
||||
// BuildTableTemplateData builds TableTemplateData from a models.Table
|
||||
func BuildTableTemplateData(schema string, table *models.Table) TableTemplateData {
|
||||
// Get sorted columns
|
||||
columns := make([]*models.Column, 0, len(table.Columns))
|
||||
for _, col := range table.Columns {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
// Find primary key constraint
|
||||
var pk *models.Constraint
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.PrimaryKeyConstraint {
|
||||
pk = constraint
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no explicit primary key constraint, build one from columns with IsPrimaryKey=true
|
||||
if pk == nil {
|
||||
pkCols := []string{}
|
||||
for _, col := range table.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
pkCols = append(pkCols, col.Name)
|
||||
}
|
||||
}
|
||||
if len(pkCols) > 0 {
|
||||
pk = &models.Constraint{
|
||||
Name: "pk_" + table.Name,
|
||||
Type: models.PrimaryKeyConstraint,
|
||||
Columns: pkCols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return TableTemplateData{
|
||||
Schema: schema,
|
||||
Name: table.Name,
|
||||
Columns: columns,
|
||||
PrimaryKey: pk,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
-- Check constraint: {{.Name}}
|
||||
-- {{.Expression}}
|
||||
-- Note: SQLite supports CHECK constraints in CREATE TABLE or ALTER TABLE ADD CHECK
|
||||
-- This must be added manually to the table definition above
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user