10 Commits

Author SHA1 Message Date
f2d500f98d feat(merge): 🎉 Add support for constraints and indexes in merge results
All checks were successful
CI / Test (1.24) (push) Successful in -26m24s
CI / Test (1.25) (push) Successful in -26m10s
CI / Lint (push) Successful in -26m33s
CI / Build (push) Successful in -26m40s
Release / Build and Release (push) Successful in -26m23s
Integration Tests / Integration Tests (push) Successful in -25m53s
* Enhance MergeResult to track added constraints and indexes.
* Update merge logic to increment counters for added constraints and indexes.
* Modify GetMergeSummary to include constraints and indexes in the output.
* Add comprehensive tests for merging constraints and indexes.
2026-01-31 21:30:55 +02:00
2ec9991324 feat(merge): 🎉 Add support for merging constraints and indexes
Some checks failed
CI / Test (1.24) (push) Failing after -26m37s
CI / Test (1.25) (push) Successful in -26m8s
CI / Lint (push) Successful in -26m32s
CI / Build (push) Successful in -26m42s
Release / Build and Release (push) Successful in -26m26s
Integration Tests / Integration Tests (push) Successful in -26m3s
* Implement mergeConstraints to handle table constraints
* Implement mergeIndexes to handle table indexes
* Update mergeTables to include constraints and indexes during merge
2026-01-31 21:27:28 +02:00
a3e45c206d feat(writer): 🎉 Enhance SQL execution logging and add statement type detection
All checks were successful
CI / Test (1.24) (push) Successful in -26m21s
CI / Test (1.25) (push) Successful in -26m15s
CI / Build (push) Successful in -26m39s
CI / Lint (push) Successful in -26m29s
Release / Build and Release (push) Successful in -26m28s
Integration Tests / Integration Tests (push) Successful in -26m11s
* Log statement type during execution for better debugging
* Introduce detectStatementType function to categorize SQL statements
* Update unique constraint naming convention in tests
2026-01-31 21:19:48 +02:00
165623bb1d feat(pgsql): Add templates for constraints and sequences
All checks were successful
CI / Test (1.24) (push) Successful in -26m21s
CI / Test (1.25) (push) Successful in -26m13s
CI / Build (push) Successful in -26m39s
CI / Lint (push) Successful in -26m29s
Release / Build and Release (push) Successful in -26m28s
Integration Tests / Integration Tests (push) Successful in -26m10s
* Introduce new templates for creating unique, check, and foreign key constraints with existence checks.
* Add templates for setting sequence values and creating sequences.
* Refactor existing SQL generation logic to utilize new templates for better maintainability and readability.
* Ensure identifiers are properly quoted to handle special characters and reserved keywords.
2026-01-31 21:04:43 +02:00
3c20c3c5d9 feat(writer): 🎉 Add support for check constraints in schema generation
All checks were successful
CI / Test (1.24) (push) Successful in -26m17s
CI / Test (1.25) (push) Successful in -26m14s
CI / Build (push) Successful in -26m41s
CI / Lint (push) Successful in -26m32s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m13s
* Implement check constraints in the schema writer.
* Generate SQL statements to add check constraints if they do not exist.
* Add tests to verify correct generation of check constraints.
2026-01-31 20:42:19 +02:00
a54594e49b feat(writer): 🎉 Add support for unique constraints in schema generation
All checks were successful
CI / Test (1.24) (push) Successful in -26m26s
CI / Test (1.25) (push) Successful in -26m18s
CI / Lint (push) Successful in -26m25s
CI / Build (push) Successful in -26m35s
Release / Build and Release (push) Successful in -26m29s
Integration Tests / Integration Tests (push) Successful in -26m11s
* Implement unique constraint handling in GenerateSchemaStatements
* Add writeUniqueConstraints method for generating SQL statements
* Create unit test for unique constraints in writer_test.go
2026-01-31 20:33:08 +02:00
cafe6a461f feat(scripts): 🎉 Add --ignore-errors flag for script execution
All checks were successful
CI / Test (1.24) (push) Successful in -26m18s
CI / Test (1.25) (push) Successful in -26m14s
CI / Build (push) Successful in -26m38s
CI / Lint (push) Successful in -26m30s
Release / Build and Release (push) Successful in -26m27s
Integration Tests / Integration Tests (push) Successful in -26m10s
- Allow continued execution of scripts even if errors occur.
- Update execution summary to include counts of successful and failed scripts.
- Enhance error handling and reporting for better visibility.
2026-01-31 20:21:22 +02:00
abdb9b4c78 feat(dbml/reader): 🎉 Implement splitIdentifier function for parsing
All checks were successful
CI / Test (1.24) (push) Successful in -26m24s
CI / Test (1.25) (push) Successful in -26m17s
CI / Build (push) Successful in -26m44s
CI / Lint (push) Successful in -26m33s
Integration Tests / Integration Tests (push) Successful in -26m11s
Release / Build and Release (push) Successful in -26m36s
2026-01-31 19:45:24 +02:00
e7a15c8e4f feat(writer): 🎉 Implement add column statements for schema evolution
All checks were successful
CI / Test (1.24) (push) Successful in -26m24s
CI / Test (1.25) (push) Successful in -26m14s
CI / Lint (push) Successful in -26m30s
CI / Build (push) Successful in -26m41s
Release / Build and Release (push) Successful in -26m29s
Integration Tests / Integration Tests (push) Successful in -26m13s
* Add functionality to generate ALTER TABLE ADD COLUMN statements for existing tables.
* Introduce tests for generating and writing add column statements.
* Enhance schema evolution capabilities when new columns are added.
2026-01-31 19:12:00 +02:00
c36b5ede2b feat(writer): 🎉 Enhance primary key handling and add tests
All checks were successful
CI / Test (1.24) (push) Successful in -26m18s
CI / Test (1.25) (push) Successful in -26m11s
CI / Build (push) Successful in -26m43s
CI / Lint (push) Successful in -26m34s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m20s
* Implement checks for existing primary keys before adding new ones.
* Drop auto-generated primary keys if they exist.
* Add tests for primary key existence and column size specifiers.
* Improve type conversion handling for PostgreSQL compatibility.
2026-01-31 18:59:32 +02:00
30 changed files with 2480 additions and 139 deletions

View File

@@ -14,10 +14,11 @@ import (
)
var (
scriptsDir string
scriptsConn string
scriptsSchemaName string
scriptsDBName string
scriptsDir string
scriptsConn string
scriptsSchemaName string
scriptsDBName string
scriptsIgnoreErrors bool
)
var scriptsCmd = &cobra.Command{
@@ -62,7 +63,7 @@ var scriptsExecuteCmd = &cobra.Command{
Long: `Execute SQL scripts from a directory against a PostgreSQL database.
Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).
Execution stops immediately on the first error.
By default, execution stops immediately on the first error. Use --ignore-errors to continue execution.
The directory is scanned recursively for all subdirectories and files matching the patterns:
{priority}_{sequence}_{name}.sql or .pgsql (underscore format)
@@ -86,7 +87,12 @@ Examples:
# Execute with SSL disabled
relspec scripts execute --dir ./sql \
--conn "postgres://user:pass@localhost/db?sslmode=disable"`,
--conn "postgres://user:pass@localhost/db?sslmode=disable"
# Continue executing even if errors occur
relspec scripts execute --dir ./migrations \
--conn "postgres://localhost/mydb" \
--ignore-errors`,
RunE: runScriptsExecute,
}
@@ -105,6 +111,7 @@ func init() {
scriptsExecuteCmd.Flags().StringVar(&scriptsConn, "conn", "", "PostgreSQL connection string (required)")
scriptsExecuteCmd.Flags().StringVar(&scriptsSchemaName, "schema", "public", "Schema name (optional, default: public)")
scriptsExecuteCmd.Flags().StringVar(&scriptsDBName, "database", "database", "Database name (optional, default: database)")
scriptsExecuteCmd.Flags().BoolVar(&scriptsIgnoreErrors, "ignore-errors", false, "Continue executing scripts even if errors occur")
err = scriptsExecuteCmd.MarkFlagRequired("dir")
if err != nil {
@@ -250,17 +257,39 @@ func runScriptsExecute(cmd *cobra.Command, args []string) error {
writer := sqlexec.NewWriter(&writers.WriterOptions{
Metadata: map[string]any{
"connection_string": scriptsConn,
"ignore_errors": scriptsIgnoreErrors,
},
})
if err := writer.WriteSchema(schema); err != nil {
fmt.Fprintf(os.Stderr, "\n")
return fmt.Errorf("execution failed: %w", err)
return fmt.Errorf("script execution failed: %w", err)
}
// Get execution results from writer metadata
totalCount := len(schema.Scripts)
successCount := totalCount
failedCount := 0
opts := writer.Options()
if total, exists := opts.Metadata["execution_total"].(int); exists {
totalCount = total
}
if success, exists := opts.Metadata["execution_success"].(int); exists {
successCount = success
}
if failed, exists := opts.Metadata["execution_failed"].(int); exists {
failedCount = failed
}
fmt.Fprintf(os.Stderr, "\n=== Execution Complete ===\n")
fmt.Fprintf(os.Stderr, "Completed at: %s\n", getCurrentTimestamp())
fmt.Fprintf(os.Stderr, "Successfully executed %d script(s)\n\n", len(schema.Scripts))
fmt.Fprintf(os.Stderr, "Total scripts: %d\n", totalCount)
fmt.Fprintf(os.Stderr, "Successful: %d\n", successCount)
if failedCount > 0 {
fmt.Fprintf(os.Stderr, "Failed: %d\n", failedCount)
}
fmt.Fprintf(os.Stderr, "\n")
return nil
}

View File

@@ -12,14 +12,16 @@ import (
// MergeResult represents the result of a merge operation
type MergeResult struct {
SchemasAdded int
TablesAdded int
ColumnsAdded int
RelationsAdded int
DomainsAdded int
EnumsAdded int
ViewsAdded int
SequencesAdded int
SchemasAdded int
TablesAdded int
ColumnsAdded int
ConstraintsAdded int
IndexesAdded int
RelationsAdded int
DomainsAdded int
EnumsAdded int
ViewsAdded int
SequencesAdded int
}
// MergeOptions contains options for merge operations
@@ -120,8 +122,10 @@ func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema,
}
if tgtTable, exists := existingTables[tableName]; exists {
// Table exists, merge its columns
// Table exists, merge its columns, constraints, and indexes
r.mergeColumns(tgtTable, srcTable)
r.mergeConstraints(tgtTable, srcTable)
r.mergeIndexes(tgtTable, srcTable)
} else {
// Table doesn't exist, add it
newTable := cloneTable(srcTable)
@@ -151,6 +155,52 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
}
}
func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Table) {
// Initialize constraints map if nil
if table.Constraints == nil {
table.Constraints = make(map[string]*models.Constraint)
}
// Create map of existing constraints
existingConstraints := make(map[string]*models.Constraint)
for constName := range table.Constraints {
existingConstraints[constName] = table.Constraints[constName]
}
// Merge constraints
for constName, srcConst := range srcTable.Constraints {
if _, exists := existingConstraints[constName]; !exists {
// Constraint doesn't exist, add it
newConst := cloneConstraint(srcConst)
table.Constraints[constName] = newConst
r.ConstraintsAdded++
}
}
}
func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table) {
// Initialize indexes map if nil
if table.Indexes == nil {
table.Indexes = make(map[string]*models.Index)
}
// Create map of existing indexes
existingIndexes := make(map[string]*models.Index)
for idxName := range table.Indexes {
existingIndexes[idxName] = table.Indexes[idxName]
}
// Merge indexes
for idxName, srcIdx := range srcTable.Indexes {
if _, exists := existingIndexes[idxName]; !exists {
// Index doesn't exist, add it
newIdx := cloneIndex(srcIdx)
table.Indexes[idxName] = newIdx
r.IndexesAdded++
}
}
}
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
// Create map of existing views
existingViews := make(map[string]*models.View)
@@ -552,6 +602,8 @@ func GetMergeSummary(result *MergeResult) string {
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
fmt.Sprintf("Tables added: %d", result.TablesAdded),
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
fmt.Sprintf("Constraints added: %d", result.ConstraintsAdded),
fmt.Sprintf("Indexes added: %d", result.IndexesAdded),
fmt.Sprintf("Views added: %d", result.ViewsAdded),
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
@@ -560,6 +612,7 @@ func GetMergeSummary(result *MergeResult) string {
}
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
result.ConstraintsAdded + result.IndexesAdded +
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
result.RelationsAdded + result.DomainsAdded

617
pkg/merge/merge_test.go Normal file
View File

@@ -0,0 +1,617 @@
package merge
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestMergeDatabases_NilInputs(t *testing.T) {
result := MergeDatabases(nil, nil, nil)
if result == nil {
t.Fatal("Expected non-nil result")
}
if result.SchemasAdded != 0 {
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
}
}
func TestMergeDatabases_NewSchema(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{Name: "auth"},
},
}
result := MergeDatabases(target, source, nil)
if result.SchemasAdded != 1 {
t.Errorf("Expected 1 schema added, got %d", result.SchemasAdded)
}
if len(target.Schemas) != 2 {
t.Errorf("Expected 2 schemas in target, got %d", len(target.Schemas))
}
}
func TestMergeDatabases_ExistingSchema(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
result := MergeDatabases(target, source, nil)
if result.SchemasAdded != 0 {
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
}
if len(target.Schemas) != 1 {
t.Errorf("Expected 1 schema in target, got %d", len(target.Schemas))
}
}
func TestMergeTables_NewTable(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "posts",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.TablesAdded != 1 {
t.Errorf("Expected 1 table added, got %d", result.TablesAdded)
}
if len(target.Schemas[0].Tables) != 2 {
t.Errorf("Expected 2 tables in target schema, got %d", len(target.Schemas[0].Tables))
}
}
func TestMergeColumns_NewColumn(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "int"},
},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"email": {Name: "email", Type: "varchar"},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ColumnsAdded != 1 {
t.Errorf("Expected 1 column added, got %d", result.ColumnsAdded)
}
if len(target.Schemas[0].Tables[0].Columns) != 2 {
t.Errorf("Expected 2 columns in target table, got %d", len(target.Schemas[0].Tables[0].Columns))
}
}
func TestMergeConstraints_NewConstraint(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ConstraintsAdded != 1 {
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
}
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
}
}
func TestMergeConstraints_NilConstraintsMap(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: nil, // Nil map
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ConstraintsAdded != 1 {
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
}
if target.Schemas[0].Tables[0].Constraints == nil {
t.Error("Expected constraints map to be initialized")
}
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
}
}
func TestMergeIndexes_NewIndex(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
}
}
func TestMergeIndexes_NilIndexesMap(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: nil, // Nil map
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
if target.Schemas[0].Tables[0].Indexes == nil {
t.Error("Expected indexes map to be initialized")
}
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
}
}
func TestMergeOptions_SkipTableNames(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "migrations",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
opts := &MergeOptions{
SkipTableNames: map[string]bool{
"migrations": true,
},
}
result := MergeDatabases(target, source, opts)
if result.TablesAdded != 0 {
t.Errorf("Expected 0 tables added (skipped), got %d", result.TablesAdded)
}
if len(target.Schemas[0].Tables) != 1 {
t.Errorf("Expected 1 table in target schema, got %d", len(target.Schemas[0].Tables))
}
}
func TestMergeViews_NewView(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Views: []*models.View{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Views: []*models.View{
{
Name: "user_summary",
Schema: "public",
Definition: "SELECT * FROM users",
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ViewsAdded != 1 {
t.Errorf("Expected 1 view added, got %d", result.ViewsAdded)
}
if len(target.Schemas[0].Views) != 1 {
t.Errorf("Expected 1 view in target schema, got %d", len(target.Schemas[0].Views))
}
}
func TestMergeEnums_NewEnum(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Enums: []*models.Enum{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Enums: []*models.Enum{
{
Name: "user_role",
Schema: "public",
Values: []string{"admin", "user"},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.EnumsAdded != 1 {
t.Errorf("Expected 1 enum added, got %d", result.EnumsAdded)
}
if len(target.Schemas[0].Enums) != 1 {
t.Errorf("Expected 1 enum in target schema, got %d", len(target.Schemas[0].Enums))
}
}
func TestMergeDomains_NewDomain(t *testing.T) {
target := &models.Database{
Domains: []*models.Domain{},
}
source := &models.Database{
Domains: []*models.Domain{
{
Name: "auth",
Description: "Authentication domain",
},
},
}
result := MergeDatabases(target, source, nil)
if result.DomainsAdded != 1 {
t.Errorf("Expected 1 domain added, got %d", result.DomainsAdded)
}
if len(target.Domains) != 1 {
t.Errorf("Expected 1 domain in target, got %d", len(target.Domains))
}
}
func TestMergeRelations_NewRelation(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Relations: []*models.Relationship{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Relations: []*models.Relationship{
{
Name: "fk_posts_user",
Type: models.OneToMany,
FromTable: "posts",
FromColumns: []string{"user_id"},
ToTable: "users",
ToColumns: []string{"id"},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.RelationsAdded != 1 {
t.Errorf("Expected 1 relation added, got %d", result.RelationsAdded)
}
if len(target.Schemas[0].Relations) != 1 {
t.Errorf("Expected 1 relation in target schema, got %d", len(target.Schemas[0].Relations))
}
}
func TestGetMergeSummary(t *testing.T) {
result := &MergeResult{
SchemasAdded: 1,
TablesAdded: 2,
ColumnsAdded: 5,
ConstraintsAdded: 3,
IndexesAdded: 2,
ViewsAdded: 1,
}
summary := GetMergeSummary(result)
if summary == "" {
t.Error("Expected non-empty summary")
}
if len(summary) < 50 {
t.Errorf("Summary seems too short: %s", summary)
}
}
func TestGetMergeSummary_Nil(t *testing.T) {
summary := GetMergeSummary(nil)
if summary == "" {
t.Error("Expected non-empty summary for nil result")
}
}
func TestComplexMerge(t *testing.T) {
// Target with existing structure
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "int"},
},
Constraints: map[string]*models.Constraint{},
Indexes: map[string]*models.Index{},
},
},
},
},
}
// Source with new columns, constraints, and indexes
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"email": {Name: "email", Type: "varchar"},
"guid": {Name: "guid", Type: "uuid"},
},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
"ukey_users_guid": {
Type: models.UniqueConstraint,
Columns: []string{"guid"},
Name: "ukey_users_guid",
},
},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
// Verify counts
if result.ColumnsAdded != 2 {
t.Errorf("Expected 2 columns added, got %d", result.ColumnsAdded)
}
if result.ConstraintsAdded != 2 {
t.Errorf("Expected 2 constraints added, got %d", result.ConstraintsAdded)
}
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
// Verify target has merged data
table := target.Schemas[0].Tables[0]
if len(table.Columns) != 3 {
t.Errorf("Expected 3 columns in merged table, got %d", len(table.Columns))
}
if len(table.Constraints) != 2 {
t.Errorf("Expected 2 constraints in merged table, got %d", len(table.Constraints))
}
if len(table.Indexes) != 1 {
t.Errorf("Expected 1 index in merged table, got %d", len(table.Indexes))
}
// Verify specific constraint
if _, exists := table.Constraints["ukey_users_guid"]; !exists {
t.Error("Expected ukey_users_guid constraint to exist")
}
}

View File

@@ -128,6 +128,46 @@ func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
return db, nil
}
// splitIdentifier splits a dotted identifier while respecting quotes
// Handles cases like: "schema.with.dots"."table"."column"
func splitIdentifier(s string) []string {
var parts []string
var current strings.Builder
inQuote := false
quoteChar := byte(0)
for i := 0; i < len(s); i++ {
ch := s[i]
if !inQuote {
switch ch {
case '"', '\'':
inQuote = true
quoteChar = ch
current.WriteByte(ch)
case '.':
if current.Len() > 0 {
parts = append(parts, current.String())
current.Reset()
}
default:
current.WriteByte(ch)
}
} else {
current.WriteByte(ch)
if ch == quoteChar {
inQuote = false
}
}
}
if current.Len() > 0 {
parts = append(parts, current.String())
}
return parts
}
// stripQuotes removes surrounding quotes and comments from an identifier
func stripQuotes(s string) string {
s = strings.TrimSpace(s)
@@ -409,7 +449,9 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
// Parse Table definition
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
tableName := matches[1]
parts := strings.Split(tableName, ".")
// Strip comments/notes before parsing to avoid dots in notes
tableName = strings.TrimSpace(regexp.MustCompile(`\s*\[.*?\]\s*`).ReplaceAllString(tableName, ""))
parts := splitIdentifier(tableName)
if len(parts) == 2 {
currentSchema = stripQuotes(parts[0])
@@ -561,8 +603,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
column.Default = strings.Trim(defaultVal, "'\"")
} else if attr == "unique" {
// Create a unique constraint
// Clean table name by removing leading underscores to avoid double underscores
cleanTableName := strings.TrimLeft(tableName, "_")
uniqueConstraint := models.InitConstraint(
fmt.Sprintf("uq_%s", columnName),
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
models.UniqueConstraint,
)
uniqueConstraint.Schema = schemaName
@@ -610,8 +654,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
constraint.Table = tableName
constraint.Columns = []string{columnName}
}
// Generate short constraint name based on the column
constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
// Generate constraint name based on table and columns
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
}
}
}
@@ -695,7 +739,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Generate name if not provided
if index.Name == "" {
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
prefix := "idx"
if index.Unique {
prefix = "uidx"
}
index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_"))
}
return index
@@ -755,10 +803,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint {
return nil
}
// Generate short constraint name based on the source column
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
if len(fromColumns) > 0 {
constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
// Generate constraint name based on table and columns
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_"))
if len(fromColumns) == 0 {
constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable)
}
constraint := models.InitConstraint(
@@ -814,7 +862,7 @@ func (r *Reader) parseTableRef(ref string) (schema, table string, columns []stri
}
// Parse schema, table, and optionally column
parts := strings.Split(strings.TrimSpace(ref), ".")
parts := splitIdentifier(strings.TrimSpace(ref))
if len(parts) == 3 {
// Format: "schema"."table"."column"
schema = stripQuotes(parts[0])

View File

@@ -777,6 +777,76 @@ func TestParseFilePrefix(t *testing.T) {
}
}
func TestConstraintNaming(t *testing.T) {
// Test that constraints are named with proper prefixes
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"),
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() error = %v", err)
}
// Find users table
var usersTable *models.Table
var postsTable *models.Table
for _, schema := range db.Schemas {
for _, table := range schema.Tables {
if table.Name == "users" {
usersTable = table
} else if table.Name == "posts" {
postsTable = table
}
}
}
if usersTable == nil {
t.Fatal("Users table not found")
}
if postsTable == nil {
t.Fatal("Posts table not found")
}
// Test unique constraint naming: ukey_table_column
if _, exists := usersTable.Constraints["ukey_users_email"]; !exists {
t.Error("Expected unique constraint 'ukey_users_email' not found")
t.Logf("Available constraints: %v", getKeys(usersTable.Constraints))
}
if _, exists := postsTable.Constraints["ukey_posts_slug"]; !exists {
t.Error("Expected unique constraint 'ukey_posts_slug' not found")
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
}
// Test foreign key naming: fk_table_column
if _, exists := postsTable.Constraints["fk_posts_user_id"]; !exists {
t.Error("Expected foreign key 'fk_posts_user_id' not found")
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
}
// Test unique index naming: uidx_table_columns
if _, exists := postsTable.Indexes["uidx_posts_slug"]; !exists {
t.Error("Expected unique index 'uidx_posts_slug' not found")
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
}
// Test regular index naming: idx_table_columns
if _, exists := postsTable.Indexes["idx_posts_user_id_published"]; !exists {
t.Error("Expected index 'idx_posts_user_id_published' not found")
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
}
}
func getKeys[V any](m map[string]V) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}
func TestHasCommentedRefs(t *testing.T) {
// Test with the actual multifile test fixtures
tests := []struct {

View File

@@ -215,3 +215,70 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
t.Error("SQL missing DELETE handling")
}
}
func TestWriteMigration_NumericConstraintNames(t *testing.T) {
// Current database (empty)
current := models.InitDatabase("testdb")
currentSchema := models.InitSchema("entity")
current.Schemas = append(current.Schemas, currentSchema)
// Model database (with constraint starting with number)
model := models.InitDatabase("testdb")
modelSchema := models.InitSchema("entity")
// Create individual_actor_relationship table
table := models.InitTable("individual_actor_relationship", "entity")
idCol := models.InitColumn("id", "individual_actor_relationship", "entity")
idCol.Type = "integer"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
actorIDCol := models.InitColumn("actor_id", "individual_actor_relationship", "entity")
actorIDCol.Type = "integer"
table.Columns["actor_id"] = actorIDCol
// Add constraint with name starting with number
constraint := &models.Constraint{
Name: "215162_fk_actor",
Type: models.ForeignKeyConstraint,
Columns: []string{"actor_id"},
ReferencedSchema: "entity",
ReferencedTable: "actor",
ReferencedColumns: []string{"id"},
OnDelete: "CASCADE",
OnUpdate: "NO ACTION",
}
table.Constraints["215162_fk_actor"] = constraint
modelSchema.Tables = append(modelSchema.Tables, table)
model.Schemas = append(model.Schemas, modelSchema)
// Generate migration
var buf bytes.Buffer
writer, err := NewMigrationWriter(&writers.WriterOptions{})
if err != nil {
t.Fatalf("Failed to create writer: %v", err)
}
writer.writer = &buf
err = writer.WriteMigration(model, current)
if err != nil {
t.Fatalf("WriteMigration failed: %v", err)
}
output := buf.String()
t.Logf("Generated migration:\n%s", output)
// Verify constraint name is properly quoted
if !strings.Contains(output, `"215162_fk_actor"`) {
t.Error("Constraint name starting with number should be quoted")
}
// Verify the SQL is syntactically correct (contains required keywords)
if !strings.Contains(output, "ADD CONSTRAINT") {
t.Error("Migration missing ADD CONSTRAINT")
}
if !strings.Contains(output, "FOREIGN KEY") {
t.Error("Migration missing FOREIGN KEY")
}
}

View File

@@ -21,6 +21,7 @@ func TemplateFunctions() map[string]interface{} {
"quote": quote,
"escape": escape,
"safe_identifier": safeIdentifier,
"quote_ident": quoteIdent,
// Type conversion
"goTypeToSQL": goTypeToSQL,
@@ -122,6 +123,43 @@ func safeIdentifier(s string) string {
return strings.ToLower(safe)
}
// quoteIdent quotes a PostgreSQL identifier if necessary
// Identifiers need quoting if they:
// - Start with a digit
// - Contain special characters
// - Are reserved keywords
// - Contain uppercase letters (to preserve case)
func quoteIdent(s string) string {
if s == "" {
return `""`
}
// Check if quoting is needed
needsQuoting := unicode.IsDigit(rune(s[0]))
// Starts with digit
// Contains uppercase letters or special characters
for _, r := range s {
if unicode.IsUpper(r) {
needsQuoting = true
break
}
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
needsQuoting = true
break
}
}
if needsQuoting {
// Escape double quotes by doubling them
escaped := strings.ReplaceAll(s, `"`, `""`)
return `"` + escaped + `"`
}
return s
}
// Type conversion functions
// goTypeToSQL converts Go type to PostgreSQL type

View File

@@ -101,6 +101,31 @@ func TestSafeIdentifier(t *testing.T) {
}
}
func TestQuoteIdent(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"valid_name", "valid_name"},
{"ValidName", `"ValidName"`},
{"123column", `"123column"`},
{"215162_fk_constraint", `"215162_fk_constraint"`},
{"user-id", `"user-id"`},
{"user@domain", `"user@domain"`},
{`"quoted"`, `"""quoted"""`},
{"", `""`},
{"lowercase", "lowercase"},
{"with_underscore", "with_underscore"},
}
for _, tt := range tests {
result := quoteIdent(tt.input)
if result != tt.expected {
t.Errorf("quoteIdent(%q) = %q, want %q", tt.input, result, tt.expected)
}
}
}
func TestGoTypeToSQL(t *testing.T) {
tests := []struct {
input string
@@ -243,7 +268,7 @@ func TestTemplateFunctions(t *testing.T) {
// Check that all expected functions are registered
expectedFuncs := []string{
"upper", "lower", "snake_case", "camelCase",
"indent", "quote", "escape", "safe_identifier",
"indent", "quote", "escape", "safe_identifier", "quote_ident",
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
"first", "last", "filter", "mapFunc", "join_with",
"join",

View File

@@ -177,6 +177,72 @@ type AuditTriggerData struct {
Events string
}
// CreateUniqueConstraintData contains data for create unique constraint template
type CreateUniqueConstraintData struct {
SchemaName string
TableName string
ConstraintName string
Columns string
}
// CreateCheckConstraintData contains data for create check constraint template
type CreateCheckConstraintData struct {
SchemaName string
TableName string
ConstraintName string
Expression string
}
// CreateForeignKeyWithCheckData contains data for create foreign key with existence check template
type CreateForeignKeyWithCheckData struct {
SchemaName string
TableName string
ConstraintName string
SourceColumns string
TargetSchema string
TargetTable string
TargetColumns string
OnDelete string
OnUpdate string
Deferrable bool
}
// SetSequenceValueData contains data for set sequence value template
type SetSequenceValueData struct {
SchemaName string
TableName string
SequenceName string
ColumnName string
}
// CreateSequenceData contains data for create sequence template
type CreateSequenceData struct {
SchemaName string
SequenceName string
Increment int
MinValue int64
MaxValue int64
StartValue int64
CacheSize int
}
// AddColumnWithCheckData contains data for add column with existence check template
type AddColumnWithCheckData struct {
SchemaName string
TableName string
ColumnName string
ColumnDefinition string
}
// CreatePrimaryKeyWithAutoGenCheckData contains data for primary key with auto-generated key check template
type CreatePrimaryKeyWithAutoGenCheckData struct {
SchemaName string
TableName string
ConstraintName string
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
Columns string
}
// Execute methods for each template
// ExecuteCreateTable executes the create table template
@@ -319,6 +385,76 @@ func (te *TemplateExecutor) ExecuteAuditTrigger(data AuditTriggerData) (string,
return buf.String(), nil
}
// ExecuteCreateUniqueConstraint executes the create unique constraint template
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data CreateUniqueConstraintData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_unique_constraint.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_unique_constraint template: %w", err)
}
return buf.String(), nil
}
// ExecuteCreateCheckConstraint executes the create check constraint template
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data CreateCheckConstraintData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_check_constraint.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_check_constraint template: %w", err)
}
return buf.String(), nil
}
// ExecuteCreateForeignKeyWithCheck executes the create foreign key with check template
func (te *TemplateExecutor) ExecuteCreateForeignKeyWithCheck(data CreateForeignKeyWithCheckData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_foreign_key_with_check.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_foreign_key_with_check template: %w", err)
}
return buf.String(), nil
}
// ExecuteSetSequenceValue executes the set sequence value template
func (te *TemplateExecutor) ExecuteSetSequenceValue(data SetSequenceValueData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "set_sequence_value.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute set_sequence_value template: %w", err)
}
return buf.String(), nil
}
// ExecuteCreateSequence executes the create sequence template
func (te *TemplateExecutor) ExecuteCreateSequence(data CreateSequenceData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_sequence.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_sequence template: %w", err)
}
return buf.String(), nil
}
// ExecuteAddColumnWithCheck executes the add column with check template
func (te *TemplateExecutor) ExecuteAddColumnWithCheck(data AddColumnWithCheckData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "add_column_with_check.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute add_column_with_check template: %w", err)
}
return buf.String(), nil
}
// ExecuteCreatePrimaryKeyWithAutoGenCheck executes the create primary key with auto-generated key check template
func (te *TemplateExecutor) ExecuteCreatePrimaryKeyWithAutoGenCheck(data CreatePrimaryKeyWithAutoGenCheckData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "create_primary_key_with_autogen_check.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute create_primary_key_with_autogen_check template: %w", err)
}
return buf.String(), nil
}
// Helper functions to build template data from models
// BuildCreateTableData builds CreateTableData from a models.Table

View File

@@ -1,4 +1,4 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ADD COLUMN IF NOT EXISTS {{.ColumnName}} {{.ColumnType}}
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
{{- if .Default}} DEFAULT {{.Default}}{{end}}
{{- if .NotNull}} NOT NULL{{end}};

View File

@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND column_name = '{{.ColumnName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD COLUMN {{.ColumnDefinition}};
END IF;
END;
$$;

View File

@@ -1,7 +1,7 @@
{{- if .SetDefault -}}
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} SET DEFAULT {{.DefaultValue}};
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
{{- else -}}
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} DROP DEFAULT;
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
{{- end -}}

View File

@@ -1,2 +1,2 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} TYPE {{.NewType}};
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};

View File

@@ -1 +1 @@
COMMENT ON COLUMN {{.SchemaName}}.{{.TableName}}.{{.ColumnName}} IS '{{.Comment}}';
COMMENT ON COLUMN {{quote_ident .SchemaName}}.{{quote_ident .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';

View File

@@ -1 +1 @@
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS '{{.Comment}}';
COMMENT ON TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} IS '{{.Comment}}';

View File

@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
END IF;
END;
$$;

View File

@@ -1,10 +1,10 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}}
DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ADD CONSTRAINT {{.ConstraintName}}
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ADD CONSTRAINT {{quote_ident .ConstraintName}}
FOREIGN KEY ({{.SourceColumns}})
REFERENCES {{.TargetSchema}}.{{.TargetTable}} ({{.TargetColumns}})
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
ON DELETE {{.OnDelete}}
ON UPDATE {{.OnUpdate}}
DEFERRABLE;

View File

@@ -0,0 +1,18 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ADD CONSTRAINT {{quote_ident .ConstraintName}}
FOREIGN KEY ({{.SourceColumns}})
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
ON DELETE {{.OnDelete}}
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
DEFERRABLE{{end}};
END IF;
END;
$$;

View File

@@ -1,2 +1,2 @@
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
ON {{.SchemaName}}.{{.TableName}} USING {{.IndexType}} ({{.Columns}});
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
ON {{quote_ident .SchemaName}}.{{quote_ident .TableName}} USING {{.IndexType}} ({{.Columns}});

View File

@@ -6,8 +6,8 @@ BEGIN
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ADD CONSTRAINT {{.ConstraintName}} PRIMARY KEY ({{.Columns}});
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF;
END;
$$;

View File

@@ -0,0 +1,27 @@
DO $$
DECLARE
auto_pk_name text;
BEGIN
-- Drop auto-generated primary key if it exists
SELECT constraint_name INTO auto_pk_name
FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_type = 'PRIMARY KEY'
AND constraint_name IN ({{.AutoGenNames}});
IF auto_pk_name IS NOT NULL THEN
EXECUTE 'ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
END IF;
-- Add named primary key if it doesn't exist
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF;
END;
$$;

View File

@@ -0,0 +1,6 @@
CREATE SEQUENCE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}
INCREMENT {{.Increment}}
MINVALUE {{.MinValue}}
MAXVALUE {{.MaxValue}}
START {{.StartValue}}
CACHE {{.CacheSize}};

View File

@@ -1,7 +1,7 @@
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
CREATE TABLE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .TableName}} (
{{- range $i, $col := .Columns}}
{{- if $i}},{{end}}
{{$col.Name}} {{$col.Type}}
{{quote_ident $col.Name}} {{$col.Type}}
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
{{- if $col.NotNull}} NOT NULL{{end}}
{{- end}}

View File

@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
END IF;
END;
$$;

View File

@@ -1 +1 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}} DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};

View File

@@ -1 +1 @@
DROP INDEX IF EXISTS {{.SchemaName}}.{{.IndexName}} CASCADE;
DROP INDEX IF EXISTS {{quote_ident .SchemaName}}.{{quote_ident .IndexName}} CASCADE;

View File

@@ -0,0 +1,19 @@
DO $$
DECLARE
m_cnt bigint;
BEGIN
IF EXISTS (
SELECT 1 FROM pg_class c
INNER JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE c.relname = '{{.SequenceName}}'
AND n.nspname = '{{.SchemaName}}'
AND c.relkind = 'S'
) THEN
SELECT COALESCE(MAX({{quote_ident .ColumnName}}), 0) + 1
FROM {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
INTO m_cnt;
PERFORM setval('{{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}'::regclass, m_cnt);
END IF;
END;
$$;

View File

@@ -22,6 +22,7 @@ type Writer struct {
options *writers.WriterOptions
writer io.Writer
executionReport *ExecutionReport
executor *TemplateExecutor
}
// ExecutionReport tracks the execution status of SQL statements
@@ -57,8 +58,10 @@ type ExecutionError struct {
// NewWriter creates a new PostgreSQL SQL writer
func NewWriter(options *writers.WriterOptions) *Writer {
executor, _ := NewTemplateExecutor()
return &Writer{
options: options,
options: options,
executor: executor,
}
}
@@ -168,6 +171,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
statements = append(statements, stmts...)
}
// Phase 3.5: Add missing columns (for existing tables)
addColStmts, err := w.GenerateAddColumnStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate add column statements: %w", err)
}
statements = append(statements, addColStmts...)
// Phase 4: Primary keys
for _, table := range schema.Tables {
// First check for explicit PrimaryKeyConstraint
@@ -179,27 +189,50 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
}
}
var pkColumns []string
var pkName string
if pkConstraint != nil {
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
schema.SQLName(), table.SQLName(), pkConstraint.Name, strings.Join(pkConstraint.Columns, ", "))
statements = append(statements, stmt)
pkColumns = pkConstraint.Columns
pkName = pkConstraint.Name
} else {
// No explicit constraint, check for columns with IsPrimaryKey = true
pkColumns := []string{}
pkCols := []string{}
for _, col := range table.Columns {
if col.IsPrimaryKey {
pkColumns = append(pkColumns, col.SQLName())
pkCols = append(pkCols, col.SQLName())
}
}
if len(pkColumns) > 0 {
if len(pkCols) > 0 {
// Sort for consistent output
sort.Strings(pkColumns)
pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
statements = append(statements, stmt)
sort.Strings(pkCols)
pkColumns = pkCols
pkName = fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
}
}
if len(pkColumns) > 0 {
// Auto-generated primary key names to check for and drop
autoGenPKNames := []string{
fmt.Sprintf("%s_pkey", table.Name),
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
}
// Use template to generate primary key statement
data := CreatePrimaryKeyWithAutoGenCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(pkColumns, ", "),
}
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
if err != nil {
return nil, fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
}
statements = append(statements, stmt)
}
}
// Phase 5: Indexes
@@ -243,7 +276,53 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
}
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
uniqueStr, quoteIdentifier(index.Name), schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
statements = append(statements, stmt)
}
}
// Phase 5.5: Unique constraints
for _, table := range schema.Tables {
for _, constraint := range table.Constraints {
if constraint.Type != models.UniqueConstraint {
continue
}
// Use template to generate unique constraint statement
data := CreateUniqueConstraintData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
Columns: strings.Join(constraint.Columns, ", "),
}
stmt, err := w.executor.ExecuteCreateUniqueConstraint(data)
if err != nil {
return nil, fmt.Errorf("failed to generate unique constraint for %s.%s: %w", schema.Name, table.Name, err)
}
statements = append(statements, stmt)
}
}
// Phase 5.7: Check constraints
for _, table := range schema.Tables {
for _, constraint := range table.Constraints {
if constraint.Type != models.CheckConstraint {
continue
}
// Use template to generate check constraint statement
data := CreateCheckConstraintData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
Expression: constraint.Expression,
}
stmt, err := w.executor.ExecuteCreateCheckConstraint(data)
if err != nil {
return nil, fmt.Errorf("failed to generate check constraint for %s.%s: %w", schema.Name, table.Name, err)
}
statements = append(statements, stmt)
}
}
@@ -270,12 +349,24 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
onUpdate = "NO ACTION"
}
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
schema.SQLName(), table.SQLName(), constraint.Name,
strings.Join(constraint.Columns, ", "),
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
strings.Join(constraint.ReferencedColumns, ", "),
onDelete, onUpdate)
// Use template to generate foreign key statement
data := CreateForeignKeyWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
SourceColumns: strings.Join(constraint.Columns, ", "),
TargetSchema: refSchema,
TargetTable: constraint.ReferencedTable,
TargetColumns: strings.Join(constraint.ReferencedColumns, ", "),
OnDelete: onDelete,
OnUpdate: onUpdate,
Deferrable: false,
}
stmt, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
if err != nil {
return nil, fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
}
statements = append(statements, stmt)
}
}
@@ -300,6 +391,67 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
return statements, nil
}
// GenerateAddColumnStatements generates ALTER TABLE ADD COLUMN statements for existing tables
// This is useful for schema evolution when new columns are added to existing tables
func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, error) {
statements := []string{}
statements = append(statements, fmt.Sprintf("-- Add missing columns for schema: %s", schema.Name))
for _, table := range schema.Tables {
// Sort columns by sequence or name for consistent output
columns := make([]*models.Column, 0, len(table.Columns))
for _, col := range table.Columns {
columns = append(columns, col)
}
sort.Slice(columns, func(i, j int) bool {
if columns[i].Sequence != columns[j].Sequence {
return columns[i].Sequence < columns[j].Sequence
}
return columns[i].Name < columns[j].Name
})
for _, col := range columns {
colDef := w.generateColumnDefinition(col)
// Use template to generate add column statement
data := AddColumnWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ColumnName: col.Name,
ColumnDefinition: colDef,
}
stmt, err := w.executor.ExecuteAddColumnWithCheck(data)
if err != nil {
return nil, fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
}
statements = append(statements, stmt)
}
}
return statements, nil
}
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
statements := []string{}
statements = append(statements, "-- Add missing columns to existing tables")
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
statements = append(statements, "-- Generated by RelSpec")
for _, schema := range db.Schemas {
schemaStatements, err := w.GenerateAddColumnStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate add column statements for schema %s: %w", schema.Name, err)
}
statements = append(statements, schemaStatements...)
}
return statements, nil
}
// generateCreateTableStatement generates CREATE TABLE statement
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
statements := []string{}
@@ -322,7 +474,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
columnDefs = append(columnDefs, " "+def)
}
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (\n%s\n)",
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
statements = append(statements, stmt)
@@ -336,14 +488,25 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
// Type with length/precision - convert to valid PostgreSQL type
baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType
// Only add size specifiers for types that support them
if col.Length > 0 && col.Precision == 0 {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if col.Precision > 0 {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
if supportsLength(baseType) {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if isTextTypeWithoutLength(baseType) {
// Convert text with length to varchar
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
}
// For types that don't support length (integer, bigint, etc.), ignore the length
} else if col.Precision > 0 {
if supportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
}
}
// For types that don't support precision, ignore it
}
parts = append(parts, typeStr)
@@ -396,6 +559,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err
}
// Phase 3.5: Add missing columns (priority 120)
if err := w.writeAddColumns(schema); err != nil {
return err
}
// Phase 4: Create primary keys (priority 160)
if err := w.writePrimaryKeys(schema); err != nil {
return err
@@ -406,6 +574,16 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err
}
// Phase 5.5: Create unique constraints (priority 185)
if err := w.writeUniqueConstraints(schema); err != nil {
return err
}
// Phase 5.7: Create check constraints (priority 190)
if err := w.writeCheckConstraints(schema); err != nil {
return err
}
// Phase 6: Create foreign key constraints (priority 195)
if err := w.writeForeignKeys(schema); err != nil {
return err
@@ -437,6 +615,44 @@ func (w *Writer) WriteTable(table *models.Table) error {
return w.WriteSchema(schema)
}
// WriteAddColumnStatements writes ALTER TABLE ADD COLUMN statements for a database
// This is used for schema evolution/migration when new columns are added
func (w *Writer) WriteAddColumnStatements(db *models.Database) error {
var writer io.Writer
var file *os.File
var err error
// Use existing writer if already set (for testing)
if w.writer != nil {
writer = w.writer
} else if w.options.OutputPath != "" {
// Determine output destination
file, err = os.Create(w.options.OutputPath)
if err != nil {
return fmt.Errorf("failed to create output file: %w", err)
}
defer file.Close()
writer = file
} else {
writer = os.Stdout
}
w.writer = writer
// Generate statements
statements, err := w.GenerateAddColumnsForDatabase(db)
if err != nil {
return err
}
// Write each statement
for _, stmt := range statements {
fmt.Fprintf(w.writer, "%s;\n\n", stmt)
}
return nil
}
// writeCreateSchema generates CREATE SCHEMA statement
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
if schema.Name == "public" {
@@ -465,13 +681,23 @@ func (w *Writer) writeSequences(schema *models.Schema) error {
}
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n",
schema.SQLName(), seqName)
fmt.Fprintf(w.writer, " INCREMENT 1\n")
fmt.Fprintf(w.writer, " MINVALUE 1\n")
fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n")
fmt.Fprintf(w.writer, " START 1\n")
fmt.Fprintf(w.writer, " CACHE 1;\n\n")
data := CreateSequenceData{
SchemaName: schema.Name,
SequenceName: seqName,
Increment: 1,
MinValue: 1,
MaxValue: 9223372036854775807,
StartValue: 1,
CacheSize: 1,
}
sql, err := w.executor.ExecuteCreateSequence(data)
if err != nil {
return fmt.Errorf("failed to generate create sequence for %s.%s: %w", schema.Name, seqName, err)
}
fmt.Fprint(w.writer, sql)
fmt.Fprint(w.writer, "\n")
}
return nil
@@ -490,15 +716,8 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
columnDefs := make([]string, 0, len(columns))
for _, col := range columns {
colDef := fmt.Sprintf(" %s %s", col.SQLName(), pgsql.ConvertSQLType(col.Type))
// Add default value if present
if col.Default != nil && col.Default != "" {
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
defaultVal := fmt.Sprintf("%v", col.Default)
colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal))
}
// Use generateColumnDefinition to properly handle type, length, precision, and defaults
colDef := " " + w.generateColumnDefinition(col)
columnDefs = append(columnDefs, colDef)
}
@@ -509,6 +728,36 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
return nil
}
// writeAddColumns generates ALTER TABLE ADD COLUMN statements for missing columns
func (w *Writer) writeAddColumns(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Add missing columns for schema: %s\n", schema.Name)
for _, table := range schema.Tables {
// Sort columns by sequence or name for consistent output
columns := getSortedColumns(table.Columns)
for _, col := range columns {
colDef := w.generateColumnDefinition(col)
data := AddColumnWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ColumnName: col.Name,
ColumnDefinition: colDef,
}
sql, err := w.executor.ExecuteAddColumnWithCheck(data)
if err != nil {
return fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
}
fmt.Fprint(w.writer, sql)
fmt.Fprint(w.writer, "\n")
}
}
return nil
}
// writePrimaryKeys generates ALTER TABLE statements for primary keys
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
@@ -550,18 +799,26 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
continue
}
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName)
fmt.Fprintf(w.writer, " ) THEN\n")
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n",
pkName, strings.Join(columnNames, ", "))
fmt.Fprintf(w.writer, " END IF;\n")
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
// Auto-generated primary key names to check for and drop
autoGenPKNames := []string{
fmt.Sprintf("%s_pkey", table.Name),
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
}
data := CreatePrimaryKeyWithAutoGenCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(columnNames, ", "),
}
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
if err != nil {
return fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
}
fmt.Fprint(w.writer, sql)
fmt.Fprint(w.writer, "\n")
}
return nil
@@ -644,6 +901,91 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
return nil
}
// writeUniqueConstraints generates ALTER TABLE statements for unique constraints
func (w *Writer) writeUniqueConstraints(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Unique constraints for schema: %s\n", schema.Name)
for _, table := range schema.Tables {
// Sort constraints by name for consistent output
constraintNames := make([]string, 0, len(table.Constraints))
for name, constraint := range table.Constraints {
if constraint.Type == models.UniqueConstraint {
constraintNames = append(constraintNames, name)
}
}
sort.Strings(constraintNames)
for _, name := range constraintNames {
constraint := table.Constraints[name]
// Build column list
columnExprs := make([]string, 0, len(constraint.Columns))
for _, colName := range constraint.Columns {
if col, ok := table.Columns[colName]; ok {
columnExprs = append(columnExprs, col.SQLName())
}
}
if len(columnExprs) == 0 {
continue
}
sql, err := w.executor.ExecuteCreateUniqueConstraint(CreateUniqueConstraintData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
Columns: strings.Join(columnExprs, ", "),
})
if err != nil {
return fmt.Errorf("failed to generate unique constraint: %w", err)
}
fmt.Fprintf(w.writer, "%s\n\n", sql)
}
}
return nil
}
// writeCheckConstraints generates ALTER TABLE statements for check constraints
func (w *Writer) writeCheckConstraints(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Check constraints for schema: %s\n", schema.Name)
for _, table := range schema.Tables {
// Sort constraints by name for consistent output
constraintNames := make([]string, 0, len(table.Constraints))
for name, constraint := range table.Constraints {
if constraint.Type == models.CheckConstraint {
constraintNames = append(constraintNames, name)
}
}
sort.Strings(constraintNames)
for _, name := range constraintNames {
constraint := table.Constraints[name]
// Skip if expression is empty
if constraint.Expression == "" {
continue
}
sql, err := w.executor.ExecuteCreateCheckConstraint(CreateCheckConstraintData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
Expression: constraint.Expression,
})
if err != nil {
return fmt.Errorf("failed to generate check constraint: %w", err)
}
fmt.Fprintf(w.writer, "%s\n\n", sql)
}
}
return nil
}
// writeForeignKeys generates ALTER TABLE statements for foreign keys
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
@@ -711,13 +1053,6 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
}
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName)
fmt.Fprintf(w.writer, "\n")
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
refSchema := fkConstraint.ReferencedSchema
if refSchema == "" {
@@ -728,11 +1063,103 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
refTable = rel.ToTable
}
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
refSchema, refTable, strings.Join(targetColumns, ", "))
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
fmt.Fprintf(w.writer, " DEFERRABLE;\n\n")
// Use template executor to generate foreign key with existence check
data := CreateForeignKeyWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: fkName,
SourceColumns: strings.Join(sourceColumns, ", "),
TargetSchema: refSchema,
TargetTable: refTable,
TargetColumns: strings.Join(targetColumns, ", "),
OnDelete: onDelete,
OnUpdate: onUpdate,
Deferrable: true,
}
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
if err != nil {
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
}
fmt.Fprint(w.writer, sql)
}
// Also process any foreign key constraints that don't have a relationship
processedConstraints := make(map[string]bool)
for _, rel := range table.Relationships {
fkName := rel.ForeignKey
if fkName == "" {
fkName = rel.Name
}
if fkName != "" {
processedConstraints[fkName] = true
}
}
// Find unprocessed foreign key constraints
constraintNames := make([]string, 0)
for name, constraint := range table.Constraints {
if constraint.Type == models.ForeignKeyConstraint && !processedConstraints[name] {
constraintNames = append(constraintNames, name)
}
}
sort.Strings(constraintNames)
for _, name := range constraintNames {
constraint := table.Constraints[name]
// Build column lists
sourceColumns := make([]string, 0, len(constraint.Columns))
for _, colName := range constraint.Columns {
if col, ok := table.Columns[colName]; ok {
sourceColumns = append(sourceColumns, col.SQLName())
} else {
sourceColumns = append(sourceColumns, colName)
}
}
targetColumns := make([]string, 0, len(constraint.ReferencedColumns))
for _, colName := range constraint.ReferencedColumns {
targetColumns = append(targetColumns, strings.ToLower(colName))
}
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
continue
}
onDelete := "NO ACTION"
if constraint.OnDelete != "" {
onDelete = strings.ToUpper(constraint.OnDelete)
}
onUpdate := "NO ACTION"
if constraint.OnUpdate != "" {
onUpdate = strings.ToUpper(constraint.OnUpdate)
}
refSchema := constraint.ReferencedSchema
if refSchema == "" {
refSchema = schema.Name
}
refTable := constraint.ReferencedTable
// Use template executor to generate foreign key with existence check
data := CreateForeignKeyWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
SourceColumns: strings.Join(sourceColumns, ", "),
TargetSchema: refSchema,
TargetTable: refTable,
TargetColumns: strings.Join(targetColumns, ", "),
OnDelete: onDelete,
OnUpdate: onUpdate,
Deferrable: false,
}
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
if err != nil {
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
}
fmt.Fprint(w.writer, sql)
}
}
@@ -751,26 +1178,19 @@ func (w *Writer) writeSetSequenceValues(schema *models.Schema) error {
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
fmt.Fprintf(w.writer, "DO $$\n")
fmt.Fprintf(w.writer, "DECLARE\n")
fmt.Fprintf(w.writer, " m_cnt bigint;\n")
fmt.Fprintf(w.writer, "BEGIN\n")
fmt.Fprintf(w.writer, " IF EXISTS (\n")
fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n")
fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n")
fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName)
fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name)
fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n")
fmt.Fprintf(w.writer, " ) THEN\n")
fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName())
fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " INTO m_cnt;\n")
fmt.Fprintf(w.writer, " \n")
fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n",
schema.SQLName(), seqName)
fmt.Fprintf(w.writer, " END IF;\n")
fmt.Fprintf(w.writer, "END;\n")
fmt.Fprintf(w.writer, "$$;\n\n")
// Use template executor to generate set sequence value statement
data := SetSequenceValueData{
SchemaName: schema.Name,
TableName: table.Name,
SequenceName: seqName,
ColumnName: pk.Name,
}
sql, err := w.executor.ExecuteSetSequenceValue(data)
if err != nil {
return fmt.Errorf("failed to generate set sequence value for %s.%s: %w", schema.Name, table.Name, err)
}
fmt.Fprint(w.writer, sql)
fmt.Fprint(w.writer, "\n")
}
return nil
@@ -844,6 +1264,44 @@ func isTextType(colType string) bool {
return false
}
// supportsLength checks if a PostgreSQL type supports length specification
func supportsLength(colType string) bool {
lengthTypes := []string{"varchar", "character varying", "char", "character", "bit", "bit varying", "varbit"}
lowerType := strings.ToLower(colType)
for _, t := range lengthTypes {
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
return true
}
}
return false
}
// supportsPrecision checks if a PostgreSQL type supports precision/scale specification
func supportsPrecision(colType string) bool {
precisionTypes := []string{"numeric", "decimal", "time", "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone", "time with time zone", "time without time zone", "interval"}
lowerType := strings.ToLower(colType)
for _, t := range precisionTypes {
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
return true
}
}
return false
}
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
func isTextTypeWithoutLength(colType string) bool {
return strings.EqualFold(colType, "text")
}
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
func formatStringList(items []string) string {
quoted := make([]string, len(items))
for i, item := range items {
quoted[i] = fmt.Sprintf("'%s'", escapeQuote(item))
}
return strings.Join(quoted, ", ")
}
// extractOperatorClass extracts operator class from index comment/note
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
func extractOperatorClass(comment string) string {
@@ -952,7 +1410,8 @@ func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) erro
continue
}
fmt.Fprintf(os.Stderr, "Executing statement %d/%d...\n", i+1, len(statements))
stmtType := detectStatementType(stmtTrimmed)
fmt.Fprintf(os.Stderr, "Executing statement %d/%d [%s]...\n", i+1, len(statements), stmtType)
_, execErr := conn.Exec(ctx, stmt)
if execErr != nil {
@@ -1086,3 +1545,94 @@ func truncateStatement(stmt string) string {
func getCurrentTimestamp() string {
return time.Now().Format("2006-01-02 15:04:05")
}
// detectStatementType detects the type of SQL statement for logging
func detectStatementType(stmt string) string {
upperStmt := strings.ToUpper(stmt)
// Check for DO blocks (used for conditional DDL)
if strings.HasPrefix(upperStmt, "DO $$") || strings.HasPrefix(upperStmt, "DO $") {
// Look inside the DO block for the actual operation
if strings.Contains(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD CONSTRAINT") {
if strings.Contains(upperStmt, "UNIQUE") {
return "ADD UNIQUE CONSTRAINT"
} else if strings.Contains(upperStmt, "FOREIGN KEY") {
return "ADD FOREIGN KEY"
} else if strings.Contains(upperStmt, "PRIMARY KEY") {
return "ADD PRIMARY KEY"
} else if strings.Contains(upperStmt, "CHECK") {
return "ADD CHECK CONSTRAINT"
}
return "ADD CONSTRAINT"
}
if strings.Contains(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD COLUMN") {
return "ADD COLUMN"
}
if strings.Contains(upperStmt, "DROP CONSTRAINT") {
return "DROP CONSTRAINT"
}
return "DO BLOCK"
}
// Direct DDL statements
if strings.HasPrefix(upperStmt, "CREATE SCHEMA") {
return "CREATE SCHEMA"
}
if strings.HasPrefix(upperStmt, "CREATE SEQUENCE") {
return "CREATE SEQUENCE"
}
if strings.HasPrefix(upperStmt, "CREATE TABLE") {
return "CREATE TABLE"
}
if strings.HasPrefix(upperStmt, "CREATE INDEX") {
return "CREATE INDEX"
}
if strings.HasPrefix(upperStmt, "CREATE UNIQUE INDEX") {
return "CREATE UNIQUE INDEX"
}
if strings.HasPrefix(upperStmt, "ALTER TABLE") {
if strings.Contains(upperStmt, "ADD CONSTRAINT") {
if strings.Contains(upperStmt, "FOREIGN KEY") {
return "ADD FOREIGN KEY"
} else if strings.Contains(upperStmt, "PRIMARY KEY") {
return "ADD PRIMARY KEY"
} else if strings.Contains(upperStmt, "UNIQUE") {
return "ADD UNIQUE CONSTRAINT"
} else if strings.Contains(upperStmt, "CHECK") {
return "ADD CHECK CONSTRAINT"
}
return "ADD CONSTRAINT"
}
if strings.Contains(upperStmt, "ADD COLUMN") {
return "ADD COLUMN"
}
if strings.Contains(upperStmt, "DROP CONSTRAINT") {
return "DROP CONSTRAINT"
}
if strings.Contains(upperStmt, "ALTER COLUMN") {
return "ALTER COLUMN"
}
return "ALTER TABLE"
}
if strings.HasPrefix(upperStmt, "COMMENT ON TABLE") {
return "COMMENT ON TABLE"
}
if strings.HasPrefix(upperStmt, "COMMENT ON COLUMN") {
return "COMMENT ON COLUMN"
}
if strings.HasPrefix(upperStmt, "DROP TABLE") {
return "DROP TABLE"
}
if strings.HasPrefix(upperStmt, "DROP INDEX") {
return "DROP INDEX"
}
// Default
return "SQL"
}
// quoteIdentifier wraps an identifier in double quotes if necessary
// This is needed for identifiers that start with numbers or contain special characters
func quoteIdentifier(s string) string {
return quoteIdent(s)
}

View File

@@ -164,6 +164,296 @@ func TestWriteForeignKeys(t *testing.T) {
}
}
func TestWriteUniqueConstraints(t *testing.T) {
// Create a test database with unique constraints
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create table with unique constraints
table := models.InitTable("users", "public")
// Add columns
emailCol := models.InitColumn("email", "users", "public")
emailCol.Type = "varchar(255)"
emailCol.NotNull = true
table.Columns["email"] = emailCol
guidCol := models.InitColumn("guid", "users", "public")
guidCol.Type = "uuid"
guidCol.NotNull = true
table.Columns["guid"] = guidCol
// Add unique constraints
emailConstraint := &models.Constraint{
Name: "uq_email",
Type: models.UniqueConstraint,
Schema: "public",
Table: "users",
Columns: []string{"email"},
}
table.Constraints["uq_email"] = emailConstraint
guidConstraint := &models.Constraint{
Name: "uq_guid",
Type: models.UniqueConstraint,
Schema: "public",
Table: "users",
Columns: []string{"guid"},
}
table.Constraints["uq_guid"] = guidConstraint
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write the database
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
// Print output for debugging
t.Logf("Generated SQL:\n%s", output)
// Verify unique constraints are present
if !strings.Contains(output, "-- Unique constraints for schema: public") {
t.Errorf("Output missing unique constraints header")
}
if !strings.Contains(output, "ADD CONSTRAINT uq_email UNIQUE (email)") {
t.Errorf("Output missing uq_email unique constraint\nFull output:\n%s", output)
}
if !strings.Contains(output, "ADD CONSTRAINT uq_guid UNIQUE (guid)") {
t.Errorf("Output missing uq_guid unique constraint\nFull output:\n%s", output)
}
}
func TestWriteCheckConstraints(t *testing.T) {
// Create a test database with check constraints
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create table with check constraints
table := models.InitTable("products", "public")
// Add columns
priceCol := models.InitColumn("price", "products", "public")
priceCol.Type = "numeric(10,2)"
table.Columns["price"] = priceCol
statusCol := models.InitColumn("status", "products", "public")
statusCol.Type = "varchar(20)"
table.Columns["status"] = statusCol
quantityCol := models.InitColumn("quantity", "products", "public")
quantityCol.Type = "integer"
table.Columns["quantity"] = quantityCol
// Add check constraints
priceConstraint := &models.Constraint{
Name: "ck_price_positive",
Type: models.CheckConstraint,
Schema: "public",
Table: "products",
Expression: "price >= 0",
}
table.Constraints["ck_price_positive"] = priceConstraint
statusConstraint := &models.Constraint{
Name: "ck_status_valid",
Type: models.CheckConstraint,
Schema: "public",
Table: "products",
Expression: "status IN ('active', 'inactive', 'discontinued')",
}
table.Constraints["ck_status_valid"] = statusConstraint
quantityConstraint := &models.Constraint{
Name: "ck_quantity_nonnegative",
Type: models.CheckConstraint,
Schema: "public",
Table: "products",
Expression: "quantity >= 0",
}
table.Constraints["ck_quantity_nonnegative"] = quantityConstraint
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write the database
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
// Print output for debugging
t.Logf("Generated SQL:\n%s", output)
// Verify check constraints are present
if !strings.Contains(output, "-- Check constraints for schema: public") {
t.Errorf("Output missing check constraints header")
}
if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
}
if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
}
if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
}
}
func TestWriteAllConstraintTypes(t *testing.T) {
// Create a comprehensive test with all constraint types
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create orders table
ordersTable := models.InitTable("orders", "public")
// Add columns
idCol := models.InitColumn("id", "orders", "public")
idCol.Type = "integer"
idCol.IsPrimaryKey = true
ordersTable.Columns["id"] = idCol
userIDCol := models.InitColumn("user_id", "orders", "public")
userIDCol.Type = "integer"
userIDCol.NotNull = true
ordersTable.Columns["user_id"] = userIDCol
orderNumberCol := models.InitColumn("order_number", "orders", "public")
orderNumberCol.Type = "varchar(50)"
orderNumberCol.NotNull = true
ordersTable.Columns["order_number"] = orderNumberCol
totalCol := models.InitColumn("total", "orders", "public")
totalCol.Type = "numeric(10,2)"
ordersTable.Columns["total"] = totalCol
statusCol := models.InitColumn("status", "orders", "public")
statusCol.Type = "varchar(20)"
ordersTable.Columns["status"] = statusCol
// Add primary key constraint
pkConstraint := &models.Constraint{
Name: "pk_orders",
Type: models.PrimaryKeyConstraint,
Schema: "public",
Table: "orders",
Columns: []string{"id"},
}
ordersTable.Constraints["pk_orders"] = pkConstraint
// Add unique constraint
uniqueConstraint := &models.Constraint{
Name: "uq_order_number",
Type: models.UniqueConstraint,
Schema: "public",
Table: "orders",
Columns: []string{"order_number"},
}
ordersTable.Constraints["uq_order_number"] = uniqueConstraint
// Add check constraint
checkConstraint := &models.Constraint{
Name: "ck_total_positive",
Type: models.CheckConstraint,
Schema: "public",
Table: "orders",
Expression: "total > 0",
}
ordersTable.Constraints["ck_total_positive"] = checkConstraint
statusCheckConstraint := &models.Constraint{
Name: "ck_status_valid",
Type: models.CheckConstraint,
Schema: "public",
Table: "orders",
Expression: "status IN ('pending', 'completed', 'cancelled')",
}
ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint
// Add foreign key constraint (referencing a users table)
fkConstraint := &models.Constraint{
Name: "fk_orders_user",
Type: models.ForeignKeyConstraint,
Schema: "public",
Table: "orders",
Columns: []string{"user_id"},
ReferencedSchema: "public",
ReferencedTable: "users",
ReferencedColumns: []string{"id"},
OnDelete: "CASCADE",
OnUpdate: "CASCADE",
}
ordersTable.Constraints["fk_orders_user"] = fkConstraint
schema.Tables = append(schema.Tables, ordersTable)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write the database
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
// Print output for debugging
t.Logf("Generated SQL:\n%s", output)
// Verify all constraint types are present
expectedConstraints := map[string]string{
"Primary Key": "PRIMARY KEY",
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
"Foreign Key": "FOREIGN KEY",
}
for name, expected := range expectedConstraints {
if !strings.Contains(output, expected) {
t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
}
}
// Verify section headers
sections := []string{
"-- Primary keys for schema: public",
"-- Unique constraints for schema: public",
"-- Check constraints for schema: public",
"-- Foreign keys for schema: public",
}
for _, section := range sections {
if !strings.Contains(output, section) {
t.Errorf("Output missing section header: %s", section)
}
}
}
func TestWriteTable(t *testing.T) {
// Create a single table
table := models.InitTable("products", "public")
@@ -305,3 +595,263 @@ func TestTypeConversion(t *testing.T) {
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
}
}
func TestPrimaryKeyExistenceCheck(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("products", "public")
idCol := models.InitColumn("id", "products", "public")
idCol.Type = "integer"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
nameCol := models.InitColumn("name", "products", "public")
nameCol.Type = "text"
table.Columns["name"] = nameCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
t.Logf("Generated SQL:\n%s", output)
// Verify our naming convention is used
if !strings.Contains(output, "pk_public_products") {
t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output)
}
// Verify it drops auto-generated primary keys
if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") {
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
}
// Verify it checks for our specific named constraint before adding it
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
}
}
func TestColumnSizeSpecifiers(t *testing.T) {
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
table := models.InitTable("test_sizes", "public")
// Integer with invalid size specifier - should ignore size
integerCol := models.InitColumn("int_col", "test_sizes", "public")
integerCol.Type = "integer"
integerCol.Length = 32
table.Columns["int_col"] = integerCol
// Bigint with invalid size specifier - should ignore size
bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
bigintCol.Type = "bigint"
bigintCol.Length = 64
table.Columns["bigint_col"] = bigintCol
// Smallint with invalid size specifier - should ignore size
smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
smallintCol.Type = "smallint"
smallintCol.Length = 16
table.Columns["smallint_col"] = smallintCol
// Text with length - should convert to varchar
textCol := models.InitColumn("text_col", "test_sizes", "public")
textCol.Type = "text"
textCol.Length = 100
table.Columns["text_col"] = textCol
// Varchar with length - should keep varchar with length
varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
varcharCol.Type = "varchar"
varcharCol.Length = 50
table.Columns["varchar_col"] = varcharCol
// Decimal with precision and scale - should keep them
decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
decimalCol.Type = "decimal"
decimalCol.Precision = 19
decimalCol.Scale = 4
table.Columns["decimal_col"] = decimalCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase failed: %v", err)
}
output := buf.String()
t.Logf("Generated SQL:\n%s", output)
// Verify invalid size specifiers are NOT present
invalidPatterns := []string{
"integer(32)",
"bigint(64)",
"smallint(16)",
"text(100)",
}
for _, pattern := range invalidPatterns {
if strings.Contains(output, pattern) {
t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
}
}
// Verify valid patterns ARE present
validPatterns := []string{
"integer", // without size
"bigint", // without size
"smallint", // without size
"varchar(100)", // text converted to varchar with length
"varchar(50)", // varchar with length
"decimal(19,4)", // decimal with precision and scale
}
for _, pattern := range validPatterns {
if !strings.Contains(output, pattern) {
t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
}
}
}
func TestGenerateAddColumnStatements(t *testing.T) {
// Create a test database with tables that have new columns
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create a table with columns
table := models.InitTable("users", "public")
// Existing column
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "integer"
idCol.NotNull = true
idCol.Sequence = 1
table.Columns["id"] = idCol
// New column to be added
emailCol := models.InitColumn("email", "users", "public")
emailCol.Type = "varchar"
emailCol.Length = 255
emailCol.NotNull = true
emailCol.Sequence = 2
table.Columns["email"] = emailCol
// New column with default
statusCol := models.InitColumn("status", "users", "public")
statusCol.Type = "text"
statusCol.Default = "active"
statusCol.Sequence = 3
table.Columns["status"] = statusCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer
options := &writers.WriterOptions{}
writer := NewWriter(options)
// Generate ADD COLUMN statements
statements, err := writer.GenerateAddColumnsForDatabase(db)
if err != nil {
t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err)
}
// Join all statements to verify content
output := strings.Join(statements, "\n")
t.Logf("Generated ADD COLUMN statements:\n%s", output)
// Verify expected elements
expectedStrings := []string{
"ALTER TABLE public.users ADD COLUMN id integer NOT NULL",
"ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL",
"ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'",
"information_schema.columns",
"table_schema = 'public'",
"table_name = 'users'",
"column_name = 'id'",
"column_name = 'email'",
"column_name = 'status'",
}
for _, expected := range expectedStrings {
if !strings.Contains(output, expected) {
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
}
}
// Verify DO blocks are present for conditional adds
doBlockCount := strings.Count(output, "DO $$")
if doBlockCount < 3 {
t.Errorf("Expected at least 3 DO blocks (one per column), got %d", doBlockCount)
}
// Verify IF NOT EXISTS logic
ifNotExistsCount := strings.Count(output, "IF NOT EXISTS")
if ifNotExistsCount < 3 {
t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", ifNotExistsCount)
}
}
func TestWriteAddColumnStatements(t *testing.T) {
// Create a test database
db := models.InitDatabase("testdb")
schema := models.InitSchema("public")
// Create a table with a new column to be added
table := models.InitTable("products", "public")
idCol := models.InitColumn("id", "products", "public")
idCol.Type = "integer"
table.Columns["id"] = idCol
// New column with various properties
descCol := models.InitColumn("description", "products", "public")
descCol.Type = "text"
descCol.NotNull = false
table.Columns["description"] = descCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
// Create writer with output to buffer
var buf bytes.Buffer
options := &writers.WriterOptions{}
writer := NewWriter(options)
writer.writer = &buf
// Write ADD COLUMN statements
err := writer.WriteAddColumnStatements(db)
if err != nil {
t.Fatalf("WriteAddColumnStatements failed: %v", err)
}
output := buf.String()
t.Logf("Generated output:\n%s", output)
// Verify output contains expected elements
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN id integer") {
t.Errorf("Output missing ADD COLUMN for id\nFull output:\n%s", output)
}
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN description text") {
t.Errorf("Output missing ADD COLUMN for description\nFull output:\n%s", output)
}
if !strings.Contains(output, "DO $$") {
t.Errorf("Output missing DO block\nFull output:\n%s", output)
}
}

View File

@@ -23,6 +23,11 @@ func NewWriter(options *writers.WriterOptions) *Writer {
}
}
// Options returns the writer options (useful for reading execution results)
func (w *Writer) Options() *writers.WriterOptions {
return w.options
}
// WriteDatabase executes all scripts from all schemas in the database
func (w *Writer) WriteDatabase(db *models.Database) error {
if db == nil {
@@ -92,6 +97,22 @@ func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*
return nil
}
// Check if we should ignore errors
ignoreErrors := false
if val, ok := w.options.Metadata["ignore_errors"].(bool); ok {
ignoreErrors = val
}
// Track failed scripts and execution counts
var failedScripts []struct {
name string
priority int
sequence uint
err error
}
successCount := 0
totalCount := 0
// Sort scripts by Priority (ascending), Sequence (ascending), then Name (ascending)
sortedScripts := make([]*models.Script, len(scripts))
copy(sortedScripts, scripts)
@@ -111,18 +132,49 @@ func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*
continue
}
totalCount++
fmt.Printf("Executing script: %s (Priority=%d, Sequence=%d)\n",
script.Name, script.Priority, script.Sequence)
// Execute the SQL script
_, err := conn.Exec(ctx, script.SQL)
if err != nil {
return fmt.Errorf("failed to execute script %s (Priority=%d, Sequence=%d): %w",
if ignoreErrors {
fmt.Printf("⚠ Error executing %s: %v (continuing due to --ignore-errors)\n", script.Name, err)
failedScripts = append(failedScripts, struct {
name string
priority int
sequence uint
err error
}{
name: script.Name,
priority: script.Priority,
sequence: script.Sequence,
err: err,
})
continue
}
return fmt.Errorf("script %s (Priority=%d, Sequence=%d): %w",
script.Name, script.Priority, script.Sequence, err)
}
successCount++
fmt.Printf("✓ Successfully executed: %s\n", script.Name)
}
// Store execution results in metadata for caller
w.options.Metadata["execution_total"] = totalCount
w.options.Metadata["execution_success"] = successCount
w.options.Metadata["execution_failed"] = len(failedScripts)
// Print summary of failed scripts if any
if len(failedScripts) > 0 {
fmt.Printf("\n⚠ Failed Scripts Summary (%d failed):\n", len(failedScripts))
for i, failed := range failedScripts {
fmt.Printf(" %d. %s (Priority=%d, Sequence=%d)\n Error: %v\n",
i+1, failed.name, failed.priority, failed.sequence, failed.err)
}
}
return nil
}