Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f2d500f98d | |||
| 2ec9991324 | |||
| a3e45c206d | |||
| 165623bb1d | |||
| 3c20c3c5d9 | |||
| a54594e49b | |||
| cafe6a461f | |||
| abdb9b4c78 | |||
| e7a15c8e4f | |||
| c36b5ede2b | |||
| 51ab29f8e3 | |||
| f532fc110c | |||
| 92dff99725 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -47,3 +47,4 @@ dist/
|
|||||||
build/
|
build/
|
||||||
bin/
|
bin/
|
||||||
tests/integration/failed_statements_example.txt
|
tests/integration/failed_statements_example.txt
|
||||||
|
test_output.log
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ var (
|
|||||||
scriptsConn string
|
scriptsConn string
|
||||||
scriptsSchemaName string
|
scriptsSchemaName string
|
||||||
scriptsDBName string
|
scriptsDBName string
|
||||||
|
scriptsIgnoreErrors bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var scriptsCmd = &cobra.Command{
|
var scriptsCmd = &cobra.Command{
|
||||||
@@ -39,8 +40,8 @@ Example filenames (hyphen format):
|
|||||||
1-002-create-posts.sql # Priority 1, Sequence 2
|
1-002-create-posts.sql # Priority 1, Sequence 2
|
||||||
10-10-create-newid.pgsql # Priority 10, Sequence 10
|
10-10-create-newid.pgsql # Priority 10, Sequence 10
|
||||||
|
|
||||||
Both formats can be mixed in the same directory.
|
Both formats can be mixed in the same directory and subdirectories.
|
||||||
Scripts are executed in order: Priority (ascending), then Sequence (ascending).`,
|
Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).`,
|
||||||
}
|
}
|
||||||
|
|
||||||
var scriptsListCmd = &cobra.Command{
|
var scriptsListCmd = &cobra.Command{
|
||||||
@@ -48,8 +49,8 @@ var scriptsListCmd = &cobra.Command{
|
|||||||
Short: "List SQL scripts from a directory",
|
Short: "List SQL scripts from a directory",
|
||||||
Long: `List SQL scripts from a directory and show their execution order.
|
Long: `List SQL scripts from a directory and show their execution order.
|
||||||
|
|
||||||
The scripts are read from the specified directory and displayed in the order
|
The scripts are read recursively from the specified directory and displayed in the order
|
||||||
they would be executed (Priority ascending, then Sequence ascending).
|
they would be executed: Priority (ascending), then Sequence (ascending), then Name (alphabetical).
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
relspec scripts list --dir ./migrations`,
|
relspec scripts list --dir ./migrations`,
|
||||||
@@ -61,10 +62,10 @@ var scriptsExecuteCmd = &cobra.Command{
|
|||||||
Short: "Execute SQL scripts against a database",
|
Short: "Execute SQL scripts against a database",
|
||||||
Long: `Execute SQL scripts from a directory against a PostgreSQL database.
|
Long: `Execute SQL scripts from a directory against a PostgreSQL database.
|
||||||
|
|
||||||
Scripts are executed in order: Priority (ascending), then Sequence (ascending).
|
Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).
|
||||||
Execution stops immediately on the first error.
|
By default, execution stops immediately on the first error. Use --ignore-errors to continue execution.
|
||||||
|
|
||||||
The directory is scanned recursively for files matching the patterns:
|
The directory is scanned recursively for all subdirectories and files matching the patterns:
|
||||||
{priority}_{sequence}_{name}.sql or .pgsql (underscore format)
|
{priority}_{sequence}_{name}.sql or .pgsql (underscore format)
|
||||||
{priority}-{sequence}-{name}.sql or .pgsql (hyphen format)
|
{priority}-{sequence}-{name}.sql or .pgsql (hyphen format)
|
||||||
|
|
||||||
@@ -75,7 +76,7 @@ PostgreSQL Connection String Examples:
|
|||||||
postgresql://user:pass@host/dbname?sslmode=require
|
postgresql://user:pass@host/dbname?sslmode=require
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
# Execute migration scripts
|
# Execute migration scripts from a directory (including subdirectories)
|
||||||
relspec scripts execute --dir ./migrations \
|
relspec scripts execute --dir ./migrations \
|
||||||
--conn "postgres://user:pass@localhost:5432/mydb"
|
--conn "postgres://user:pass@localhost:5432/mydb"
|
||||||
|
|
||||||
@@ -86,7 +87,12 @@ Examples:
|
|||||||
|
|
||||||
# Execute with SSL disabled
|
# Execute with SSL disabled
|
||||||
relspec scripts execute --dir ./sql \
|
relspec scripts execute --dir ./sql \
|
||||||
--conn "postgres://user:pass@localhost/db?sslmode=disable"`,
|
--conn "postgres://user:pass@localhost/db?sslmode=disable"
|
||||||
|
|
||||||
|
# Continue executing even if errors occur
|
||||||
|
relspec scripts execute --dir ./migrations \
|
||||||
|
--conn "postgres://localhost/mydb" \
|
||||||
|
--ignore-errors`,
|
||||||
RunE: runScriptsExecute,
|
RunE: runScriptsExecute,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,6 +111,7 @@ func init() {
|
|||||||
scriptsExecuteCmd.Flags().StringVar(&scriptsConn, "conn", "", "PostgreSQL connection string (required)")
|
scriptsExecuteCmd.Flags().StringVar(&scriptsConn, "conn", "", "PostgreSQL connection string (required)")
|
||||||
scriptsExecuteCmd.Flags().StringVar(&scriptsSchemaName, "schema", "public", "Schema name (optional, default: public)")
|
scriptsExecuteCmd.Flags().StringVar(&scriptsSchemaName, "schema", "public", "Schema name (optional, default: public)")
|
||||||
scriptsExecuteCmd.Flags().StringVar(&scriptsDBName, "database", "database", "Database name (optional, default: database)")
|
scriptsExecuteCmd.Flags().StringVar(&scriptsDBName, "database", "database", "Database name (optional, default: database)")
|
||||||
|
scriptsExecuteCmd.Flags().BoolVar(&scriptsIgnoreErrors, "ignore-errors", false, "Continue executing scripts even if errors occur")
|
||||||
|
|
||||||
err = scriptsExecuteCmd.MarkFlagRequired("dir")
|
err = scriptsExecuteCmd.MarkFlagRequired("dir")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -149,7 +156,7 @@ func runScriptsList(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort scripts by Priority then Sequence
|
// Sort scripts by Priority, Sequence, then Name
|
||||||
sortedScripts := make([]*struct {
|
sortedScripts := make([]*struct {
|
||||||
name string
|
name string
|
||||||
priority int
|
priority int
|
||||||
@@ -186,7 +193,10 @@ func runScriptsList(cmd *cobra.Command, args []string) error {
|
|||||||
if sortedScripts[i].priority != sortedScripts[j].priority {
|
if sortedScripts[i].priority != sortedScripts[j].priority {
|
||||||
return sortedScripts[i].priority < sortedScripts[j].priority
|
return sortedScripts[i].priority < sortedScripts[j].priority
|
||||||
}
|
}
|
||||||
|
if sortedScripts[i].sequence != sortedScripts[j].sequence {
|
||||||
return sortedScripts[i].sequence < sortedScripts[j].sequence
|
return sortedScripts[i].sequence < sortedScripts[j].sequence
|
||||||
|
}
|
||||||
|
return sortedScripts[i].name < sortedScripts[j].name
|
||||||
})
|
})
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "Found %d script(s) in execution order:\n\n", len(sortedScripts))
|
fmt.Fprintf(os.Stderr, "Found %d script(s) in execution order:\n\n", len(sortedScripts))
|
||||||
@@ -242,22 +252,44 @@ func runScriptsExecute(cmd *cobra.Command, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, " ✓ Found %d script(s)\n\n", len(schema.Scripts))
|
fmt.Fprintf(os.Stderr, " ✓ Found %d script(s)\n\n", len(schema.Scripts))
|
||||||
|
|
||||||
// Step 2: Execute scripts
|
// Step 2: Execute scripts
|
||||||
fmt.Fprintf(os.Stderr, "[2/2] Executing scripts in order (Priority → Sequence)...\n\n")
|
fmt.Fprintf(os.Stderr, "[2/2] Executing scripts in order (Priority → Sequence → Name)...\n\n")
|
||||||
|
|
||||||
writer := sqlexec.NewWriter(&writers.WriterOptions{
|
writer := sqlexec.NewWriter(&writers.WriterOptions{
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"connection_string": scriptsConn,
|
"connection_string": scriptsConn,
|
||||||
|
"ignore_errors": scriptsIgnoreErrors,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := writer.WriteSchema(schema); err != nil {
|
if err := writer.WriteSchema(schema); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "\n")
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
return fmt.Errorf("execution failed: %w", err)
|
return fmt.Errorf("script execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get execution results from writer metadata
|
||||||
|
totalCount := len(schema.Scripts)
|
||||||
|
successCount := totalCount
|
||||||
|
failedCount := 0
|
||||||
|
|
||||||
|
opts := writer.Options()
|
||||||
|
if total, exists := opts.Metadata["execution_total"].(int); exists {
|
||||||
|
totalCount = total
|
||||||
|
}
|
||||||
|
if success, exists := opts.Metadata["execution_success"].(int); exists {
|
||||||
|
successCount = success
|
||||||
|
}
|
||||||
|
if failed, exists := opts.Metadata["execution_failed"].(int); exists {
|
||||||
|
failedCount = failed
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "\n=== Execution Complete ===\n")
|
fmt.Fprintf(os.Stderr, "\n=== Execution Complete ===\n")
|
||||||
fmt.Fprintf(os.Stderr, "Completed at: %s\n", getCurrentTimestamp())
|
fmt.Fprintf(os.Stderr, "Completed at: %s\n", getCurrentTimestamp())
|
||||||
fmt.Fprintf(os.Stderr, "Successfully executed %d script(s)\n\n", len(schema.Scripts))
|
fmt.Fprintf(os.Stderr, "Total scripts: %d\n", totalCount)
|
||||||
|
fmt.Fprintf(os.Stderr, "Successful: %d\n", successCount)
|
||||||
|
if failedCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed: %d\n", failedCount)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ type MergeResult struct {
|
|||||||
SchemasAdded int
|
SchemasAdded int
|
||||||
TablesAdded int
|
TablesAdded int
|
||||||
ColumnsAdded int
|
ColumnsAdded int
|
||||||
|
ConstraintsAdded int
|
||||||
|
IndexesAdded int
|
||||||
RelationsAdded int
|
RelationsAdded int
|
||||||
DomainsAdded int
|
DomainsAdded int
|
||||||
EnumsAdded int
|
EnumsAdded int
|
||||||
@@ -120,8 +122,10 @@ func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tgtTable, exists := existingTables[tableName]; exists {
|
if tgtTable, exists := existingTables[tableName]; exists {
|
||||||
// Table exists, merge its columns
|
// Table exists, merge its columns, constraints, and indexes
|
||||||
r.mergeColumns(tgtTable, srcTable)
|
r.mergeColumns(tgtTable, srcTable)
|
||||||
|
r.mergeConstraints(tgtTable, srcTable)
|
||||||
|
r.mergeIndexes(tgtTable, srcTable)
|
||||||
} else {
|
} else {
|
||||||
// Table doesn't exist, add it
|
// Table doesn't exist, add it
|
||||||
newTable := cloneTable(srcTable)
|
newTable := cloneTable(srcTable)
|
||||||
@@ -151,6 +155,52 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Table) {
|
||||||
|
// Initialize constraints map if nil
|
||||||
|
if table.Constraints == nil {
|
||||||
|
table.Constraints = make(map[string]*models.Constraint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create map of existing constraints
|
||||||
|
existingConstraints := make(map[string]*models.Constraint)
|
||||||
|
for constName := range table.Constraints {
|
||||||
|
existingConstraints[constName] = table.Constraints[constName]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge constraints
|
||||||
|
for constName, srcConst := range srcTable.Constraints {
|
||||||
|
if _, exists := existingConstraints[constName]; !exists {
|
||||||
|
// Constraint doesn't exist, add it
|
||||||
|
newConst := cloneConstraint(srcConst)
|
||||||
|
table.Constraints[constName] = newConst
|
||||||
|
r.ConstraintsAdded++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table) {
|
||||||
|
// Initialize indexes map if nil
|
||||||
|
if table.Indexes == nil {
|
||||||
|
table.Indexes = make(map[string]*models.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create map of existing indexes
|
||||||
|
existingIndexes := make(map[string]*models.Index)
|
||||||
|
for idxName := range table.Indexes {
|
||||||
|
existingIndexes[idxName] = table.Indexes[idxName]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge indexes
|
||||||
|
for idxName, srcIdx := range srcTable.Indexes {
|
||||||
|
if _, exists := existingIndexes[idxName]; !exists {
|
||||||
|
// Index doesn't exist, add it
|
||||||
|
newIdx := cloneIndex(srcIdx)
|
||||||
|
table.Indexes[idxName] = newIdx
|
||||||
|
r.IndexesAdded++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
|
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
|
||||||
// Create map of existing views
|
// Create map of existing views
|
||||||
existingViews := make(map[string]*models.View)
|
existingViews := make(map[string]*models.View)
|
||||||
@@ -552,6 +602,8 @@ func GetMergeSummary(result *MergeResult) string {
|
|||||||
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
|
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
|
||||||
fmt.Sprintf("Tables added: %d", result.TablesAdded),
|
fmt.Sprintf("Tables added: %d", result.TablesAdded),
|
||||||
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
|
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
|
||||||
|
fmt.Sprintf("Constraints added: %d", result.ConstraintsAdded),
|
||||||
|
fmt.Sprintf("Indexes added: %d", result.IndexesAdded),
|
||||||
fmt.Sprintf("Views added: %d", result.ViewsAdded),
|
fmt.Sprintf("Views added: %d", result.ViewsAdded),
|
||||||
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
|
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
|
||||||
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
||||||
@@ -560,6 +612,7 @@ func GetMergeSummary(result *MergeResult) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
||||||
|
result.ConstraintsAdded + result.IndexesAdded +
|
||||||
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
|
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
|
||||||
result.RelationsAdded + result.DomainsAdded
|
result.RelationsAdded + result.DomainsAdded
|
||||||
|
|
||||||
|
|||||||
617
pkg/merge/merge_test.go
Normal file
617
pkg/merge/merge_test.go
Normal file
@@ -0,0 +1,617 @@
|
|||||||
|
package merge
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeDatabases_NilInputs(t *testing.T) {
|
||||||
|
result := MergeDatabases(nil, nil, nil)
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("Expected non-nil result")
|
||||||
|
}
|
||||||
|
if result.SchemasAdded != 0 {
|
||||||
|
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeDatabases_NewSchema(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "auth"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.SchemasAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 schema added, got %d", result.SchemasAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas) != 2 {
|
||||||
|
t.Errorf("Expected 2 schemas in target, got %d", len(target.Schemas))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeDatabases_ExistingSchema(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.SchemasAdded != 0 {
|
||||||
|
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas) != 1 {
|
||||||
|
t.Errorf("Expected 1 schema in target, got %d", len(target.Schemas))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeTables_NewTable(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "posts",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.TablesAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 table added, got %d", result.TablesAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables) != 2 {
|
||||||
|
t.Errorf("Expected 2 tables in target schema, got %d", len(target.Schemas[0].Tables))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeColumns_NewColumn(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "int"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"email": {Name: "email", Type: "varchar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.ColumnsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 column added, got %d", result.ColumnsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns in target table, got %d", len(target.Schemas[0].Tables[0].Columns))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeConstraints_NewConstraint(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Constraints: map[string]*models.Constraint{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"ukey_users_email": {
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Columns: []string{"email"},
|
||||||
|
Name: "ukey_users_email",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.ConstraintsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
|
||||||
|
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeConstraints_NilConstraintsMap(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Constraints: nil, // Nil map
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"ukey_users_email": {
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Columns: []string{"email"},
|
||||||
|
Name: "ukey_users_email",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.ConstraintsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
|
||||||
|
}
|
||||||
|
if target.Schemas[0].Tables[0].Constraints == nil {
|
||||||
|
t.Error("Expected constraints map to be initialized")
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
|
||||||
|
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeIndexes_NewIndex(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Indexes: map[string]*models.Index{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Indexes: map[string]*models.Index{
|
||||||
|
"idx_users_email": {
|
||||||
|
Name: "idx_users_email",
|
||||||
|
Columns: []string{"email"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.IndexesAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
|
||||||
|
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeIndexes_NilIndexesMap(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Indexes: nil, // Nil map
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
Indexes: map[string]*models.Index{
|
||||||
|
"idx_users_email": {
|
||||||
|
Name: "idx_users_email",
|
||||||
|
Columns: []string{"email"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.IndexesAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||||
|
}
|
||||||
|
if target.Schemas[0].Tables[0].Indexes == nil {
|
||||||
|
t.Error("Expected indexes map to be initialized")
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
|
||||||
|
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeOptions_SkipTableNames(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "migrations",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &MergeOptions{
|
||||||
|
SkipTableNames: map[string]bool{
|
||||||
|
"migrations": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, opts)
|
||||||
|
if result.TablesAdded != 0 {
|
||||||
|
t.Errorf("Expected 0 tables added (skipped), got %d", result.TablesAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Tables) != 1 {
|
||||||
|
t.Errorf("Expected 1 table in target schema, got %d", len(target.Schemas[0].Tables))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeViews_NewView(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Views: []*models.View{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Views: []*models.View{
|
||||||
|
{
|
||||||
|
Name: "user_summary",
|
||||||
|
Schema: "public",
|
||||||
|
Definition: "SELECT * FROM users",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.ViewsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 view added, got %d", result.ViewsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Views) != 1 {
|
||||||
|
t.Errorf("Expected 1 view in target schema, got %d", len(target.Schemas[0].Views))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeEnums_NewEnum(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Enums: []*models.Enum{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Enums: []*models.Enum{
|
||||||
|
{
|
||||||
|
Name: "user_role",
|
||||||
|
Schema: "public",
|
||||||
|
Values: []string{"admin", "user"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.EnumsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 enum added, got %d", result.EnumsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Enums) != 1 {
|
||||||
|
t.Errorf("Expected 1 enum in target schema, got %d", len(target.Schemas[0].Enums))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeDomains_NewDomain(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Domains: []*models.Domain{},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Domains: []*models.Domain{
|
||||||
|
{
|
||||||
|
Name: "auth",
|
||||||
|
Description: "Authentication domain",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.DomainsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 domain added, got %d", result.DomainsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Domains) != 1 {
|
||||||
|
t.Errorf("Expected 1 domain in target, got %d", len(target.Domains))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeRelations_NewRelation(t *testing.T) {
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Relations: []*models.Relationship{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Relations: []*models.Relationship{
|
||||||
|
{
|
||||||
|
Name: "fk_posts_user",
|
||||||
|
Type: models.OneToMany,
|
||||||
|
FromTable: "posts",
|
||||||
|
FromColumns: []string{"user_id"},
|
||||||
|
ToTable: "users",
|
||||||
|
ToColumns: []string{"id"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
if result.RelationsAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 relation added, got %d", result.RelationsAdded)
|
||||||
|
}
|
||||||
|
if len(target.Schemas[0].Relations) != 1 {
|
||||||
|
t.Errorf("Expected 1 relation in target schema, got %d", len(target.Schemas[0].Relations))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMergeSummary(t *testing.T) {
|
||||||
|
result := &MergeResult{
|
||||||
|
SchemasAdded: 1,
|
||||||
|
TablesAdded: 2,
|
||||||
|
ColumnsAdded: 5,
|
||||||
|
ConstraintsAdded: 3,
|
||||||
|
IndexesAdded: 2,
|
||||||
|
ViewsAdded: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := GetMergeSummary(result)
|
||||||
|
if summary == "" {
|
||||||
|
t.Error("Expected non-empty summary")
|
||||||
|
}
|
||||||
|
if len(summary) < 50 {
|
||||||
|
t.Errorf("Summary seems too short: %s", summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMergeSummary_Nil(t *testing.T) {
|
||||||
|
summary := GetMergeSummary(nil)
|
||||||
|
if summary == "" {
|
||||||
|
t.Error("Expected non-empty summary for nil result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComplexMerge(t *testing.T) {
|
||||||
|
// Target with existing structure
|
||||||
|
target := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "int"},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{},
|
||||||
|
Indexes: map[string]*models.Index{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Source with new columns, constraints, and indexes
|
||||||
|
source := &models.Database{
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"email": {Name: "email", Type: "varchar"},
|
||||||
|
"guid": {Name: "guid", Type: "uuid"},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"ukey_users_email": {
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Columns: []string{"email"},
|
||||||
|
Name: "ukey_users_email",
|
||||||
|
},
|
||||||
|
"ukey_users_guid": {
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Columns: []string{"guid"},
|
||||||
|
Name: "ukey_users_guid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Indexes: map[string]*models.Index{
|
||||||
|
"idx_users_email": {
|
||||||
|
Name: "idx_users_email",
|
||||||
|
Columns: []string{"email"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := MergeDatabases(target, source, nil)
|
||||||
|
|
||||||
|
// Verify counts
|
||||||
|
if result.ColumnsAdded != 2 {
|
||||||
|
t.Errorf("Expected 2 columns added, got %d", result.ColumnsAdded)
|
||||||
|
}
|
||||||
|
if result.ConstraintsAdded != 2 {
|
||||||
|
t.Errorf("Expected 2 constraints added, got %d", result.ConstraintsAdded)
|
||||||
|
}
|
||||||
|
if result.IndexesAdded != 1 {
|
||||||
|
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify target has merged data
|
||||||
|
table := target.Schemas[0].Tables[0]
|
||||||
|
if len(table.Columns) != 3 {
|
||||||
|
t.Errorf("Expected 3 columns in merged table, got %d", len(table.Columns))
|
||||||
|
}
|
||||||
|
if len(table.Constraints) != 2 {
|
||||||
|
t.Errorf("Expected 2 constraints in merged table, got %d", len(table.Constraints))
|
||||||
|
}
|
||||||
|
if len(table.Indexes) != 1 {
|
||||||
|
t.Errorf("Expected 1 index in merged table, got %d", len(table.Indexes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify specific constraint
|
||||||
|
if _, exists := table.Constraints["ukey_users_guid"]; !exists {
|
||||||
|
t.Error("Expected ukey_users_guid constraint to exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,31 +4,31 @@ import "strings"
|
|||||||
|
|
||||||
var GoToStdTypes = map[string]string{
|
var GoToStdTypes = map[string]string{
|
||||||
"bool": "boolean",
|
"bool": "boolean",
|
||||||
"int64": "integer",
|
"int64": "bigint",
|
||||||
"int": "integer",
|
"int": "integer",
|
||||||
"int8": "integer",
|
"int8": "smallint",
|
||||||
"int16": "integer",
|
"int16": "smallint",
|
||||||
"int32": "integer",
|
"int32": "integer",
|
||||||
"uint": "integer",
|
"uint": "integer",
|
||||||
"uint8": "integer",
|
"uint8": "smallint",
|
||||||
"uint16": "integer",
|
"uint16": "smallint",
|
||||||
"uint32": "integer",
|
"uint32": "integer",
|
||||||
"uint64": "integer",
|
"uint64": "bigint",
|
||||||
"uintptr": "integer",
|
"uintptr": "bigint",
|
||||||
"znullint64": "integer",
|
"znullint64": "bigint",
|
||||||
"znullint32": "integer",
|
"znullint32": "integer",
|
||||||
"znullbyte": "integer",
|
"znullbyte": "smallint",
|
||||||
"float64": "double",
|
"float64": "double",
|
||||||
"float32": "double",
|
"float32": "double",
|
||||||
"complex64": "double",
|
"complex64": "double",
|
||||||
"complex128": "double",
|
"complex128": "double",
|
||||||
"customfloat64": "double",
|
"customfloat64": "double",
|
||||||
"string": "string",
|
"string": "text",
|
||||||
"Pointer": "integer",
|
"Pointer": "bigint",
|
||||||
"[]byte": "blob",
|
"[]byte": "blob",
|
||||||
"customdate": "string",
|
"customdate": "date",
|
||||||
"customtime": "string",
|
"customtime": "time",
|
||||||
"customtimestamp": "string",
|
"customtimestamp": "timestamp",
|
||||||
"sqlfloat64": "double",
|
"sqlfloat64": "double",
|
||||||
"sqlfloat16": "double",
|
"sqlfloat16": "double",
|
||||||
"sqluuid": "uuid",
|
"sqluuid": "uuid",
|
||||||
@@ -36,9 +36,9 @@ var GoToStdTypes = map[string]string{
|
|||||||
"sqljson": "json",
|
"sqljson": "json",
|
||||||
"sqlint64": "bigint",
|
"sqlint64": "bigint",
|
||||||
"sqlint32": "integer",
|
"sqlint32": "integer",
|
||||||
"sqlint16": "integer",
|
"sqlint16": "smallint",
|
||||||
"sqlbool": "boolean",
|
"sqlbool": "boolean",
|
||||||
"sqlstring": "string",
|
"sqlstring": "text",
|
||||||
"nullablejsonb": "jsonb",
|
"nullablejsonb": "jsonb",
|
||||||
"nullablejson": "json",
|
"nullablejson": "json",
|
||||||
"nullableuuid": "uuid",
|
"nullableuuid": "uuid",
|
||||||
@@ -67,7 +67,7 @@ var GoToPGSQLTypes = map[string]string{
|
|||||||
"float32": "real",
|
"float32": "real",
|
||||||
"complex64": "double precision",
|
"complex64": "double precision",
|
||||||
"complex128": "double precision",
|
"complex128": "double precision",
|
||||||
"customfloat64": "double precisio",
|
"customfloat64": "double precision",
|
||||||
"string": "text",
|
"string": "text",
|
||||||
"Pointer": "bigint",
|
"Pointer": "bigint",
|
||||||
"[]byte": "bytea",
|
"[]byte": "bytea",
|
||||||
@@ -81,9 +81,9 @@ var GoToPGSQLTypes = map[string]string{
|
|||||||
"sqljson": "json",
|
"sqljson": "json",
|
||||||
"sqlint64": "bigint",
|
"sqlint64": "bigint",
|
||||||
"sqlint32": "integer",
|
"sqlint32": "integer",
|
||||||
"sqlint16": "integer",
|
"sqlint16": "smallint",
|
||||||
"sqlbool": "boolean",
|
"sqlbool": "boolean",
|
||||||
"sqlstring": "string",
|
"sqlstring": "text",
|
||||||
"nullablejsonb": "jsonb",
|
"nullablejsonb": "jsonb",
|
||||||
"nullablejson": "json",
|
"nullablejson": "json",
|
||||||
"nullableuuid": "uuid",
|
"nullableuuid": "uuid",
|
||||||
|
|||||||
@@ -128,6 +128,46 @@ func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// splitIdentifier splits a dotted identifier while respecting quotes
|
||||||
|
// Handles cases like: "schema.with.dots"."table"."column"
|
||||||
|
func splitIdentifier(s string) []string {
|
||||||
|
var parts []string
|
||||||
|
var current strings.Builder
|
||||||
|
inQuote := false
|
||||||
|
quoteChar := byte(0)
|
||||||
|
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
ch := s[i]
|
||||||
|
|
||||||
|
if !inQuote {
|
||||||
|
switch ch {
|
||||||
|
case '"', '\'':
|
||||||
|
inQuote = true
|
||||||
|
quoteChar = ch
|
||||||
|
current.WriteByte(ch)
|
||||||
|
case '.':
|
||||||
|
if current.Len() > 0 {
|
||||||
|
parts = append(parts, current.String())
|
||||||
|
current.Reset()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
current.WriteByte(ch)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
current.WriteByte(ch)
|
||||||
|
if ch == quoteChar {
|
||||||
|
inQuote = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if current.Len() > 0 {
|
||||||
|
parts = append(parts, current.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
// stripQuotes removes surrounding quotes and comments from an identifier
|
// stripQuotes removes surrounding quotes and comments from an identifier
|
||||||
func stripQuotes(s string) string {
|
func stripQuotes(s string) string {
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
@@ -409,7 +449,9 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
|
|||||||
// Parse Table definition
|
// Parse Table definition
|
||||||
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
|
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
|
||||||
tableName := matches[1]
|
tableName := matches[1]
|
||||||
parts := strings.Split(tableName, ".")
|
// Strip comments/notes before parsing to avoid dots in notes
|
||||||
|
tableName = strings.TrimSpace(regexp.MustCompile(`\s*\[.*?\]\s*`).ReplaceAllString(tableName, ""))
|
||||||
|
parts := splitIdentifier(tableName)
|
||||||
|
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
currentSchema = stripQuotes(parts[0])
|
currentSchema = stripQuotes(parts[0])
|
||||||
@@ -561,8 +603,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
|||||||
column.Default = strings.Trim(defaultVal, "'\"")
|
column.Default = strings.Trim(defaultVal, "'\"")
|
||||||
} else if attr == "unique" {
|
} else if attr == "unique" {
|
||||||
// Create a unique constraint
|
// Create a unique constraint
|
||||||
|
// Clean table name by removing leading underscores to avoid double underscores
|
||||||
|
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||||
uniqueConstraint := models.InitConstraint(
|
uniqueConstraint := models.InitConstraint(
|
||||||
fmt.Sprintf("uq_%s", columnName),
|
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||||
models.UniqueConstraint,
|
models.UniqueConstraint,
|
||||||
)
|
)
|
||||||
uniqueConstraint.Schema = schemaName
|
uniqueConstraint.Schema = schemaName
|
||||||
@@ -610,8 +654,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
|||||||
constraint.Table = tableName
|
constraint.Table = tableName
|
||||||
constraint.Columns = []string{columnName}
|
constraint.Columns = []string{columnName}
|
||||||
}
|
}
|
||||||
// Generate short constraint name based on the column
|
// Generate constraint name based on table and columns
|
||||||
constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
|
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -695,7 +739,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
|||||||
|
|
||||||
// Generate name if not provided
|
// Generate name if not provided
|
||||||
if index.Name == "" {
|
if index.Name == "" {
|
||||||
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
|
prefix := "idx"
|
||||||
|
if index.Unique {
|
||||||
|
prefix = "uidx"
|
||||||
|
}
|
||||||
|
index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return index
|
return index
|
||||||
@@ -755,10 +803,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate short constraint name based on the source column
|
// Generate constraint name based on table and columns
|
||||||
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_"))
|
||||||
if len(fromColumns) > 0 {
|
if len(fromColumns) == 0 {
|
||||||
constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
|
constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
constraint := models.InitConstraint(
|
constraint := models.InitConstraint(
|
||||||
@@ -814,7 +862,7 @@ func (r *Reader) parseTableRef(ref string) (schema, table string, columns []stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse schema, table, and optionally column
|
// Parse schema, table, and optionally column
|
||||||
parts := strings.Split(strings.TrimSpace(ref), ".")
|
parts := splitIdentifier(strings.TrimSpace(ref))
|
||||||
if len(parts) == 3 {
|
if len(parts) == 3 {
|
||||||
// Format: "schema"."table"."column"
|
// Format: "schema"."table"."column"
|
||||||
schema = stripQuotes(parts[0])
|
schema = stripQuotes(parts[0])
|
||||||
|
|||||||
@@ -777,6 +777,76 @@ func TestParseFilePrefix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConstraintNaming(t *testing.T) {
|
||||||
|
// Test that constraints are named with proper prefixes
|
||||||
|
opts := &readers.ReaderOptions{
|
||||||
|
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"),
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(opts)
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find users table
|
||||||
|
var usersTable *models.Table
|
||||||
|
var postsTable *models.Table
|
||||||
|
for _, schema := range db.Schemas {
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
if table.Name == "users" {
|
||||||
|
usersTable = table
|
||||||
|
} else if table.Name == "posts" {
|
||||||
|
postsTable = table
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usersTable == nil {
|
||||||
|
t.Fatal("Users table not found")
|
||||||
|
}
|
||||||
|
if postsTable == nil {
|
||||||
|
t.Fatal("Posts table not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test unique constraint naming: ukey_table_column
|
||||||
|
if _, exists := usersTable.Constraints["ukey_users_email"]; !exists {
|
||||||
|
t.Error("Expected unique constraint 'ukey_users_email' not found")
|
||||||
|
t.Logf("Available constraints: %v", getKeys(usersTable.Constraints))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := postsTable.Constraints["ukey_posts_slug"]; !exists {
|
||||||
|
t.Error("Expected unique constraint 'ukey_posts_slug' not found")
|
||||||
|
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test foreign key naming: fk_table_column
|
||||||
|
if _, exists := postsTable.Constraints["fk_posts_user_id"]; !exists {
|
||||||
|
t.Error("Expected foreign key 'fk_posts_user_id' not found")
|
||||||
|
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test unique index naming: uidx_table_columns
|
||||||
|
if _, exists := postsTable.Indexes["uidx_posts_slug"]; !exists {
|
||||||
|
t.Error("Expected unique index 'uidx_posts_slug' not found")
|
||||||
|
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test regular index naming: idx_table_columns
|
||||||
|
if _, exists := postsTable.Indexes["idx_posts_user_id_published"]; !exists {
|
||||||
|
t.Error("Expected index 'idx_posts_user_id_published' not found")
|
||||||
|
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKeys[V any](m map[string]V) []string {
|
||||||
|
keys := make([]string, 0, len(m))
|
||||||
|
for k := range m {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
func TestHasCommentedRefs(t *testing.T) {
|
func TestHasCommentedRefs(t *testing.T) {
|
||||||
// Test with the actual multifile test fixtures
|
// Test with the actual multifile test fixtures
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ fmt.Printf("Found %d scripts\n", len(schema.Scripts))
|
|||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Recursive Directory Scanning**: Automatically scans all subdirectories
|
- **Recursive Directory Scanning**: Automatically scans all subdirectories
|
||||||
|
- **Symlink Skipping**: Symbolic links are automatically skipped (prevents loops and duplicates)
|
||||||
- **Multiple Extensions**: Supports both `.sql` and `.pgsql` files
|
- **Multiple Extensions**: Supports both `.sql` and `.pgsql` files
|
||||||
- **Flexible Naming**: Extract metadata from filename patterns
|
- **Flexible Naming**: Extract metadata from filename patterns
|
||||||
- **Error Handling**: Validates directory existence and file accessibility
|
- **Error Handling**: Validates directory existence and file accessibility
|
||||||
@@ -153,8 +154,9 @@ go test ./pkg/readers/sqldir/
|
|||||||
```
|
```
|
||||||
|
|
||||||
Tests include:
|
Tests include:
|
||||||
- Valid file parsing
|
- Valid file parsing (underscore and hyphen formats)
|
||||||
- Recursive directory scanning
|
- Recursive directory scanning
|
||||||
|
- Symlink skipping
|
||||||
- Invalid filename handling
|
- Invalid filename handling
|
||||||
- Empty directory handling
|
- Empty directory handling
|
||||||
- Error conditions
|
- Error conditions
|
||||||
|
|||||||
@@ -107,11 +107,20 @@ func (r *Reader) readScripts() ([]*models.Script, error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip directories
|
// Don't process directories as files (WalkDir still descends into them recursively)
|
||||||
if d.IsDir() {
|
if d.IsDir() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip symlinks
|
||||||
|
info, err := d.Info()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info.Mode()&os.ModeSymlink != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get filename
|
// Get filename
|
||||||
filename := d.Name()
|
filename := d.Name()
|
||||||
|
|
||||||
|
|||||||
@@ -373,3 +373,65 @@ func TestReader_MixedFormat(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReader_SkipSymlinks(t *testing.T) {
|
||||||
|
// Create temporary test directory
|
||||||
|
tempDir, err := os.MkdirTemp("", "sqldir-test-symlink-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Create a real SQL file
|
||||||
|
realFile := filepath.Join(tempDir, "1_001_real_file.sql")
|
||||||
|
if err := os.WriteFile(realFile, []byte("SELECT 1;"), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to create real file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create another file to link to
|
||||||
|
targetFile := filepath.Join(tempDir, "2_001_target.sql")
|
||||||
|
if err := os.WriteFile(targetFile, []byte("SELECT 2;"), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to create target file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a symlink to the target file (this should be skipped)
|
||||||
|
symlinkFile := filepath.Join(tempDir, "3_001_symlink.sql")
|
||||||
|
if err := os.Symlink(targetFile, symlinkFile); err != nil {
|
||||||
|
// Skip test on systems that don't support symlinks (e.g., Windows without admin)
|
||||||
|
t.Skipf("Symlink creation not supported: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create reader
|
||||||
|
reader := NewReader(&readers.ReaderOptions{
|
||||||
|
FilePath: tempDir,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Read database
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := db.Schemas[0]
|
||||||
|
|
||||||
|
// Should only have 2 scripts (real_file and target), symlink should be skipped
|
||||||
|
if len(schema.Scripts) != 2 {
|
||||||
|
t.Errorf("Expected 2 scripts (symlink should be skipped), got %d", len(schema.Scripts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the scripts are the real files, not the symlink
|
||||||
|
scriptNames := make(map[string]bool)
|
||||||
|
for _, script := range schema.Scripts {
|
||||||
|
scriptNames[script.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scriptNames["real_file"] {
|
||||||
|
t.Error("Expected 'real_file' script to be present")
|
||||||
|
}
|
||||||
|
if !scriptNames["target"] {
|
||||||
|
t.Error("Expected 'target' script to be present")
|
||||||
|
}
|
||||||
|
if scriptNames["symlink"] {
|
||||||
|
t.Error("Symlink script should have been skipped but was found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
217
pkg/writers/pgsql/NAMING_CONVENTIONS.md
Normal file
217
pkg/writers/pgsql/NAMING_CONVENTIONS.md
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
# PostgreSQL Naming Conventions
|
||||||
|
|
||||||
|
Standardized naming rules for all database objects in RelSpec PostgreSQL output.
|
||||||
|
|
||||||
|
## Quick Reference
|
||||||
|
|
||||||
|
| Object Type | Prefix | Format | Example |
|
||||||
|
| ----------------- | ----------- | ---------------------------------- | ------------------------ |
|
||||||
|
| Primary Key | `pk_` | `pk_<schema>_<table>` | `pk_public_users` |
|
||||||
|
| Foreign Key | `fk_` | `fk_<table>_<referenced_table>` | `fk_posts_users` |
|
||||||
|
| Unique Constraint | `uk_` | `uk_<table>_<column>` | `uk_users_email` |
|
||||||
|
| Unique Index | `uidx_` | `uidx_<table>_<column>` | `uidx_users_email` |
|
||||||
|
| Regular Index | `idx_` | `idx_<table>_<column>` | `idx_posts_user_id` |
|
||||||
|
| Check Constraint | `chk_` | `chk_<table>_<constraint_purpose>` | `chk_users_age_positive` |
|
||||||
|
| Sequence | `identity_` | `identity_<table>_<column>` | `identity_users_id` |
|
||||||
|
| Trigger | `t_` | `t_<purpose>_<table>` | `t_audit_users` |
|
||||||
|
| Trigger Function | `tf_` | `tf_<purpose>_<table>` | `tf_audit_users` |
|
||||||
|
|
||||||
|
## Naming Rules by Object Type
|
||||||
|
|
||||||
|
### Primary Keys
|
||||||
|
|
||||||
|
**Pattern:** `pk_<schema>_<table>`
|
||||||
|
|
||||||
|
- Include schema name to avoid collisions across schemas
|
||||||
|
- Use lowercase, snake_case format
|
||||||
|
- Examples:
|
||||||
|
- `pk_public_users`
|
||||||
|
- `pk_audit_audit_log`
|
||||||
|
- `pk_staging_temp_data`
|
||||||
|
|
||||||
|
### Foreign Keys
|
||||||
|
|
||||||
|
**Pattern:** `fk_<table>_<referenced_table>`
|
||||||
|
|
||||||
|
- Reference the table containing the FK followed by the referenced table
|
||||||
|
- Use lowercase, snake_case format
|
||||||
|
- Do NOT include column names in standard FK constraints
|
||||||
|
- Examples:
|
||||||
|
- `fk_posts_users` (posts.user_id → users.id)
|
||||||
|
- `fk_comments_posts` (comments.post_id → posts.id)
|
||||||
|
- `fk_order_items_orders` (order_items.order_id → orders.id)
|
||||||
|
|
||||||
|
### Unique Constraints
|
||||||
|
|
||||||
|
**Pattern:** `uk_<table>_<column>`
|
||||||
|
|
||||||
|
- Use `uk_` prefix strictly for database constraints (CONSTRAINT type)
|
||||||
|
- Include column name for clarity
|
||||||
|
- Examples:
|
||||||
|
- `uk_users_email`
|
||||||
|
- `uk_users_username`
|
||||||
|
- `uk_products_sku`
|
||||||
|
|
||||||
|
### Unique Indexes
|
||||||
|
|
||||||
|
**Pattern:** `uidx_<table>_<column>`
|
||||||
|
|
||||||
|
- Use `uidx_` prefix strictly for index type objects
|
||||||
|
- Distinguished from constraints for clarity and implementation flexibility
|
||||||
|
- Examples:
|
||||||
|
- `uidx_users_email`
|
||||||
|
- `uidx_sessions_token`
|
||||||
|
- `uidx_api_keys_key`
|
||||||
|
|
||||||
|
### Regular Indexes
|
||||||
|
|
||||||
|
**Pattern:** `idx_<table>_<column>`
|
||||||
|
|
||||||
|
- Standard indexes for query optimization
|
||||||
|
- Single column: `idx_<table>_<column>`
|
||||||
|
- Examples:
|
||||||
|
- `idx_posts_user_id`
|
||||||
|
- `idx_orders_created_at`
|
||||||
|
- `idx_users_status`
|
||||||
|
|
||||||
|
### Check Constraints
|
||||||
|
|
||||||
|
**Pattern:** `chk_<table>_<constraint_purpose>`
|
||||||
|
|
||||||
|
- Describe the constraint validation purpose
|
||||||
|
- Use lowercase, snake_case for the purpose
|
||||||
|
- Examples:
|
||||||
|
- `chk_users_age_positive` (CHECK (age > 0))
|
||||||
|
- `chk_orders_quantity_positive` (CHECK (quantity > 0))
|
||||||
|
- `chk_products_price_valid` (CHECK (price >= 0))
|
||||||
|
- `chk_users_status_enum` (CHECK (status IN ('active', 'inactive')))
|
||||||
|
|
||||||
|
### Sequences
|
||||||
|
|
||||||
|
**Pattern:** `identity_<table>_<column>`
|
||||||
|
|
||||||
|
- Used for SERIAL/IDENTITY columns
|
||||||
|
- Explicitly named for clarity and management
|
||||||
|
- Examples:
|
||||||
|
- `identity_users_id`
|
||||||
|
- `identity_posts_id`
|
||||||
|
- `identity_transactions_id`
|
||||||
|
|
||||||
|
### Triggers
|
||||||
|
|
||||||
|
**Pattern:** `t_<purpose>_<table>`
|
||||||
|
|
||||||
|
- Include purpose before table name
|
||||||
|
- Lowercase, snake_case format
|
||||||
|
- Examples:
|
||||||
|
- `t_audit_users` (audit trigger on users table)
|
||||||
|
- `t_update_timestamp_posts` (timestamp update trigger on posts)
|
||||||
|
- `t_validate_orders` (validation trigger on orders)
|
||||||
|
|
||||||
|
### Trigger Functions
|
||||||
|
|
||||||
|
**Pattern:** `tf_<purpose>_<table>`
|
||||||
|
|
||||||
|
- Pair with trigger naming convention
|
||||||
|
- Use `tf_` prefix to distinguish from triggers themselves
|
||||||
|
- Examples:
|
||||||
|
- `tf_audit_users` (function for t_audit_users)
|
||||||
|
- `tf_update_timestamp_posts` (function for t_update_timestamp_posts)
|
||||||
|
- `tf_validate_orders` (function for t_validate_orders)
|
||||||
|
|
||||||
|
## Multi-Column Objects
|
||||||
|
|
||||||
|
### Composite Primary Keys
|
||||||
|
|
||||||
|
**Pattern:** `pk_<schema>_<table>`
|
||||||
|
|
||||||
|
- Same as single-column PKs
|
||||||
|
- Example: `pk_public_order_items` (composite key on order_id + item_id)
|
||||||
|
|
||||||
|
### Composite Unique Constraints
|
||||||
|
|
||||||
|
**Pattern:** `uk_<table>_<column1>_<column2>_[...]`
|
||||||
|
|
||||||
|
- Append all column names in order
|
||||||
|
- Examples:
|
||||||
|
- `uk_users_email_domain` (UNIQUE(email, domain))
|
||||||
|
- `uk_inventory_warehouse_sku` (UNIQUE(warehouse_id, sku))
|
||||||
|
|
||||||
|
### Composite Unique Indexes
|
||||||
|
|
||||||
|
**Pattern:** `uidx_<table>_<column1>_<column2>_[...]`
|
||||||
|
|
||||||
|
- Append all column names in order
|
||||||
|
- Examples:
|
||||||
|
- `uidx_users_first_name_last_name` (UNIQUE INDEX on first_name, last_name)
|
||||||
|
- `uidx_sessions_user_id_device_id` (UNIQUE INDEX on user_id, device_id)
|
||||||
|
|
||||||
|
### Composite Regular Indexes
|
||||||
|
|
||||||
|
**Pattern:** `idx_<table>_<column1>_<column2>_[...]`
|
||||||
|
|
||||||
|
- Append all column names in order
|
||||||
|
- List columns in typical query filter order
|
||||||
|
- Examples:
|
||||||
|
- `idx_orders_user_id_created_at` (filter by user, then sort by created_at)
|
||||||
|
- `idx_logs_level_timestamp` (filter by level, then by timestamp)
|
||||||
|
|
||||||
|
## Special Cases & Conventions
|
||||||
|
|
||||||
|
### Audit Trail Tables
|
||||||
|
|
||||||
|
- Audit table naming: `<original_table>_audit` or `audit_<original_table>`
|
||||||
|
- Audit indexes follow standard pattern: `idx_<audit_table>_<column>`
|
||||||
|
- Examples:
|
||||||
|
- Users table audit: `users_audit` with `idx_users_audit_tablename`, `idx_users_audit_changedate`
|
||||||
|
- Posts table audit: `posts_audit` with `idx_posts_audit_tablename`, `idx_posts_audit_changedate`
|
||||||
|
|
||||||
|
### Temporal/Versioning Tables
|
||||||
|
|
||||||
|
- Use suffix `_history` or `_versions` if needed
|
||||||
|
- Apply standard naming rules with the full table name
|
||||||
|
- Examples:
|
||||||
|
- `idx_users_history_user_id`
|
||||||
|
- `uk_posts_versions_version_number`
|
||||||
|
|
||||||
|
### Schema-Specific Objects
|
||||||
|
|
||||||
|
- Always qualify with schema when needed: `pk_<schema>_<table>`
|
||||||
|
- Multiple schemas allowed: `pk_public_users`, `pk_staging_users`
|
||||||
|
|
||||||
|
### Reserved Words & Special Names
|
||||||
|
|
||||||
|
- Avoid PostgreSQL reserved keywords in object names
|
||||||
|
- If column/table names conflict, use quoted identifiers in DDL
|
||||||
|
- Naming convention rules still apply to the logical name
|
||||||
|
|
||||||
|
### Generated/Anonymous Indexes
|
||||||
|
|
||||||
|
- If an index lacks explicit naming, default to: `idx_<schema>_<table>`
|
||||||
|
- Should be replaced with explicit names following standards
|
||||||
|
- Examples (to be renamed):
|
||||||
|
- `idx_public_users` → should be `idx_users_<column>`
|
||||||
|
|
||||||
|
## Implementation Notes
|
||||||
|
|
||||||
|
### Code Generation
|
||||||
|
|
||||||
|
- Names are always lowercase in generated SQL
|
||||||
|
- Underscore separators are required
|
||||||
|
|
||||||
|
### Migration Safety
|
||||||
|
|
||||||
|
- Do NOT rename objects after creation without explicit migration
|
||||||
|
- Names should be consistent across all schema versions
|
||||||
|
- Test generated DDL against PostgreSQL before deployment
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
- Ensure consistency across all table and constraint generation
|
||||||
|
- Test with reserved words to verify escaping
|
||||||
|
|
||||||
|
## Related Documentation
|
||||||
|
|
||||||
|
- PostgreSQL Identifier Rules: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-IDENTIFIERS
|
||||||
|
- Constraint Documentation: https://www.postgresql.org/docs/current/ddl-constraints.html
|
||||||
|
- Index Documentation: https://www.postgresql.org/docs/current/indexes.html
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -335,7 +336,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
ColumnName: modelCol.Name,
|
ColumnName: modelCol.Name,
|
||||||
ColumnType: modelCol.Type,
|
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
|
||||||
Default: defaultVal,
|
Default: defaultVal,
|
||||||
NotNull: modelCol.NotNull,
|
NotNull: modelCol.NotNull,
|
||||||
})
|
})
|
||||||
@@ -359,7 +360,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
|||||||
SchemaName: schema.Name,
|
SchemaName: schema.Name,
|
||||||
TableName: modelTable.Name,
|
TableName: modelTable.Name,
|
||||||
ColumnName: modelCol.Name,
|
ColumnName: modelCol.Name,
|
||||||
NewType: modelCol.Type,
|
NewType: pgsql.ConvertSQLType(modelCol.Type),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -476,7 +477,7 @@ func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *mo
|
|||||||
}
|
}
|
||||||
if len(pkColumns) > 0 {
|
if len(pkColumns) > 0 {
|
||||||
sort.Strings(pkColumns)
|
sort.Strings(pkColumns)
|
||||||
constraintName := fmt.Sprintf("pk_%s_%s", strings.ToLower(model.Name), strings.ToLower(modelTable.Name))
|
constraintName := fmt.Sprintf("pk_%s_%s", model.SQLName(), modelTable.SQLName())
|
||||||
shouldCreate := true
|
shouldCreate := true
|
||||||
|
|
||||||
if currentTable != nil {
|
if currentTable != nil {
|
||||||
@@ -752,7 +753,7 @@ func (w *MigrationWriter) generateAuditScripts(schema *models.Schema, auditConfi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate audit function
|
// Generate audit function
|
||||||
funcName := fmt.Sprintf("ft_audit_%s", table.Name)
|
funcName := fmt.Sprintf("tf_audit_%s", table.Name)
|
||||||
funcData := BuildAuditFunctionData(schema.Name, table, pk, config, auditSchema, auditConfig.UserFunction)
|
funcData := BuildAuditFunctionData(schema.Name, table, pk, config, auditSchema, auditConfig.UserFunction)
|
||||||
|
|
||||||
funcSQL, err := w.executor.ExecuteAuditFunction(funcData)
|
funcSQL, err := w.executor.ExecuteAuditFunction(funcData)
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ func TestWriteMigration_WithAudit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify audit function
|
// Verify audit function
|
||||||
if !strings.Contains(output, "CREATE OR REPLACE FUNCTION public.ft_audit_users()") {
|
if !strings.Contains(output, "CREATE OR REPLACE FUNCTION public.tf_audit_users()") {
|
||||||
t.Error("Migration missing audit function")
|
t.Error("Migration missing audit function")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,7 +177,7 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
|||||||
|
|
||||||
data := AuditFunctionData{
|
data := AuditFunctionData{
|
||||||
SchemaName: "public",
|
SchemaName: "public",
|
||||||
FunctionName: "ft_audit_users",
|
FunctionName: "tf_audit_users",
|
||||||
TableName: "users",
|
TableName: "users",
|
||||||
TablePrefix: "NULL",
|
TablePrefix: "NULL",
|
||||||
PrimaryKey: "id",
|
PrimaryKey: "id",
|
||||||
@@ -202,7 +202,7 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
|||||||
|
|
||||||
t.Logf("Generated SQL:\n%s", sql)
|
t.Logf("Generated SQL:\n%s", sql)
|
||||||
|
|
||||||
if !strings.Contains(sql, "CREATE OR REPLACE FUNCTION public.ft_audit_users()") {
|
if !strings.Contains(sql, "CREATE OR REPLACE FUNCTION public.tf_audit_users()") {
|
||||||
t.Error("SQL missing function definition")
|
t.Error("SQL missing function definition")
|
||||||
}
|
}
|
||||||
if !strings.Contains(sql, "IF TG_OP = 'INSERT'") {
|
if !strings.Contains(sql, "IF TG_OP = 'INSERT'") {
|
||||||
@@ -215,3 +215,70 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
|||||||
t.Error("SQL missing DELETE handling")
|
t.Error("SQL missing DELETE handling")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteMigration_NumericConstraintNames(t *testing.T) {
|
||||||
|
// Current database (empty)
|
||||||
|
current := models.InitDatabase("testdb")
|
||||||
|
currentSchema := models.InitSchema("entity")
|
||||||
|
current.Schemas = append(current.Schemas, currentSchema)
|
||||||
|
|
||||||
|
// Model database (with constraint starting with number)
|
||||||
|
model := models.InitDatabase("testdb")
|
||||||
|
modelSchema := models.InitSchema("entity")
|
||||||
|
|
||||||
|
// Create individual_actor_relationship table
|
||||||
|
table := models.InitTable("individual_actor_relationship", "entity")
|
||||||
|
idCol := models.InitColumn("id", "individual_actor_relationship", "entity")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.IsPrimaryKey = true
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
actorIDCol := models.InitColumn("actor_id", "individual_actor_relationship", "entity")
|
||||||
|
actorIDCol.Type = "integer"
|
||||||
|
table.Columns["actor_id"] = actorIDCol
|
||||||
|
|
||||||
|
// Add constraint with name starting with number
|
||||||
|
constraint := &models.Constraint{
|
||||||
|
Name: "215162_fk_actor",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"actor_id"},
|
||||||
|
ReferencedSchema: "entity",
|
||||||
|
ReferencedTable: "actor",
|
||||||
|
ReferencedColumns: []string{"id"},
|
||||||
|
OnDelete: "CASCADE",
|
||||||
|
OnUpdate: "NO ACTION",
|
||||||
|
}
|
||||||
|
table.Constraints["215162_fk_actor"] = constraint
|
||||||
|
|
||||||
|
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||||
|
model.Schemas = append(model.Schemas, modelSchema)
|
||||||
|
|
||||||
|
// Generate migration
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create writer: %v", err)
|
||||||
|
}
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err = writer.WriteMigration(model, current)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteMigration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated migration:\n%s", output)
|
||||||
|
|
||||||
|
// Verify constraint name is properly quoted
|
||||||
|
if !strings.Contains(output, `"215162_fk_actor"`) {
|
||||||
|
t.Error("Constraint name starting with number should be quoted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the SQL is syntactically correct (contains required keywords)
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT") {
|
||||||
|
t.Error("Migration missing ADD CONSTRAINT")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "FOREIGN KEY") {
|
||||||
|
t.Error("Migration missing FOREIGN KEY")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ func TemplateFunctions() map[string]interface{} {
|
|||||||
"quote": quote,
|
"quote": quote,
|
||||||
"escape": escape,
|
"escape": escape,
|
||||||
"safe_identifier": safeIdentifier,
|
"safe_identifier": safeIdentifier,
|
||||||
|
"quote_ident": quoteIdent,
|
||||||
|
|
||||||
// Type conversion
|
// Type conversion
|
||||||
"goTypeToSQL": goTypeToSQL,
|
"goTypeToSQL": goTypeToSQL,
|
||||||
@@ -122,6 +123,43 @@ func safeIdentifier(s string) string {
|
|||||||
return strings.ToLower(safe)
|
return strings.ToLower(safe)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// quoteIdent quotes a PostgreSQL identifier if necessary
|
||||||
|
// Identifiers need quoting if they:
|
||||||
|
// - Start with a digit
|
||||||
|
// - Contain special characters
|
||||||
|
// - Are reserved keywords
|
||||||
|
// - Contain uppercase letters (to preserve case)
|
||||||
|
func quoteIdent(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return `""`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if quoting is needed
|
||||||
|
needsQuoting := unicode.IsDigit(rune(s[0]))
|
||||||
|
|
||||||
|
// Starts with digit
|
||||||
|
|
||||||
|
// Contains uppercase letters or special characters
|
||||||
|
for _, r := range s {
|
||||||
|
if unicode.IsUpper(r) {
|
||||||
|
needsQuoting = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
|
||||||
|
needsQuoting = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsQuoting {
|
||||||
|
// Escape double quotes by doubling them
|
||||||
|
escaped := strings.ReplaceAll(s, `"`, `""`)
|
||||||
|
return `"` + escaped + `"`
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// Type conversion functions
|
// Type conversion functions
|
||||||
|
|
||||||
// goTypeToSQL converts Go type to PostgreSQL type
|
// goTypeToSQL converts Go type to PostgreSQL type
|
||||||
|
|||||||
@@ -101,6 +101,31 @@ func TestSafeIdentifier(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQuoteIdent(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"valid_name", "valid_name"},
|
||||||
|
{"ValidName", `"ValidName"`},
|
||||||
|
{"123column", `"123column"`},
|
||||||
|
{"215162_fk_constraint", `"215162_fk_constraint"`},
|
||||||
|
{"user-id", `"user-id"`},
|
||||||
|
{"user@domain", `"user@domain"`},
|
||||||
|
{`"quoted"`, `"""quoted"""`},
|
||||||
|
{"", `""`},
|
||||||
|
{"lowercase", "lowercase"},
|
||||||
|
{"with_underscore", "with_underscore"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := quoteIdent(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("quoteIdent(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGoTypeToSQL(t *testing.T) {
|
func TestGoTypeToSQL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
@@ -243,7 +268,7 @@ func TestTemplateFunctions(t *testing.T) {
|
|||||||
// Check that all expected functions are registered
|
// Check that all expected functions are registered
|
||||||
expectedFuncs := []string{
|
expectedFuncs := []string{
|
||||||
"upper", "lower", "snake_case", "camelCase",
|
"upper", "lower", "snake_case", "camelCase",
|
||||||
"indent", "quote", "escape", "safe_identifier",
|
"indent", "quote", "escape", "safe_identifier", "quote_ident",
|
||||||
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
|
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
|
||||||
"first", "last", "filter", "mapFunc", "join_with",
|
"first", "last", "filter", "mapFunc", "join_with",
|
||||||
"join",
|
"join",
|
||||||
|
|||||||
@@ -177,6 +177,72 @@ type AuditTriggerData struct {
|
|||||||
Events string
|
Events string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateUniqueConstraintData contains data for create unique constraint template
|
||||||
|
type CreateUniqueConstraintData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ConstraintName string
|
||||||
|
Columns string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateCheckConstraintData contains data for create check constraint template
|
||||||
|
type CreateCheckConstraintData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ConstraintName string
|
||||||
|
Expression string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateForeignKeyWithCheckData contains data for create foreign key with existence check template
|
||||||
|
type CreateForeignKeyWithCheckData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ConstraintName string
|
||||||
|
SourceColumns string
|
||||||
|
TargetSchema string
|
||||||
|
TargetTable string
|
||||||
|
TargetColumns string
|
||||||
|
OnDelete string
|
||||||
|
OnUpdate string
|
||||||
|
Deferrable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSequenceValueData contains data for set sequence value template
|
||||||
|
type SetSequenceValueData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
SequenceName string
|
||||||
|
ColumnName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSequenceData contains data for create sequence template
|
||||||
|
type CreateSequenceData struct {
|
||||||
|
SchemaName string
|
||||||
|
SequenceName string
|
||||||
|
Increment int
|
||||||
|
MinValue int64
|
||||||
|
MaxValue int64
|
||||||
|
StartValue int64
|
||||||
|
CacheSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddColumnWithCheckData contains data for add column with existence check template
|
||||||
|
type AddColumnWithCheckData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
ColumnDefinition string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePrimaryKeyWithAutoGenCheckData contains data for primary key with auto-generated key check template
|
||||||
|
type CreatePrimaryKeyWithAutoGenCheckData struct {
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ConstraintName string
|
||||||
|
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
|
||||||
|
Columns string
|
||||||
|
}
|
||||||
|
|
||||||
// Execute methods for each template
|
// Execute methods for each template
|
||||||
|
|
||||||
// ExecuteCreateTable executes the create table template
|
// ExecuteCreateTable executes the create table template
|
||||||
@@ -319,6 +385,76 @@ func (te *TemplateExecutor) ExecuteAuditTrigger(data AuditTriggerData) (string,
|
|||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteCreateUniqueConstraint executes the create unique constraint template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data CreateUniqueConstraintData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_unique_constraint.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_unique_constraint template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteCreateCheckConstraint executes the create check constraint template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data CreateCheckConstraintData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_check_constraint.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_check_constraint template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteCreateForeignKeyWithCheck executes the create foreign key with check template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreateForeignKeyWithCheck(data CreateForeignKeyWithCheckData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_foreign_key_with_check.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_foreign_key_with_check template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteSetSequenceValue executes the set sequence value template
|
||||||
|
func (te *TemplateExecutor) ExecuteSetSequenceValue(data SetSequenceValueData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "set_sequence_value.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute set_sequence_value template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteCreateSequence executes the create sequence template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreateSequence(data CreateSequenceData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_sequence.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_sequence template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteAddColumnWithCheck executes the add column with check template
|
||||||
|
func (te *TemplateExecutor) ExecuteAddColumnWithCheck(data AddColumnWithCheckData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "add_column_with_check.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute add_column_with_check template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteCreatePrimaryKeyWithAutoGenCheck executes the create primary key with auto-generated key check template
|
||||||
|
func (te *TemplateExecutor) ExecuteCreatePrimaryKeyWithAutoGenCheck(data CreatePrimaryKeyWithAutoGenCheckData) (string, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := te.templates.ExecuteTemplate(&buf, "create_primary_key_with_autogen_check.tmpl", data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute create_primary_key_with_autogen_check template: %w", err)
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Helper functions to build template data from models
|
// Helper functions to build template data from models
|
||||||
|
|
||||||
// BuildCreateTableData builds CreateTableData from a models.Table
|
// BuildCreateTableData builds CreateTableData from a models.Table
|
||||||
@@ -355,7 +491,7 @@ func BuildAuditFunctionData(
|
|||||||
auditSchema string,
|
auditSchema string,
|
||||||
userFunction string,
|
userFunction string,
|
||||||
) AuditFunctionData {
|
) AuditFunctionData {
|
||||||
funcName := fmt.Sprintf("ft_audit_%s", table.Name)
|
funcName := fmt.Sprintf("tf_audit_%s", table.Name)
|
||||||
|
|
||||||
// Build list of audited columns
|
// Build list of audited columns
|
||||||
auditedColumns := make([]*models.Column, 0)
|
auditedColumns := make([]*models.Column, 0)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
ADD COLUMN IF NOT EXISTS {{.ColumnName}} {{.ColumnType}}
|
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
|
||||||
{{- if .Default}} DEFAULT {{.Default}}{{end}}
|
{{- if .Default}} DEFAULT {{.Default}}{{end}}
|
||||||
{{- if .NotNull}} NOT NULL{{end}};
|
{{- if .NotNull}} NOT NULL{{end}};
|
||||||
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND column_name = '{{.ColumnName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD COLUMN {{.ColumnDefinition}};
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
{{- if .SetDefault -}}
|
{{- if .SetDefault -}}
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
ALTER COLUMN {{.ColumnName}} SET DEFAULT {{.DefaultValue}};
|
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||||
{{- else -}}
|
{{- else -}}
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
ALTER COLUMN {{.ColumnName}} DROP DEFAULT;
|
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
ALTER COLUMN {{.ColumnName}} TYPE {{.NewType}};
|
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
||||||
@@ -1 +1 @@
|
|||||||
COMMENT ON COLUMN {{.SchemaName}}.{{.TableName}}.{{.ColumnName}} IS '{{.Comment}}';
|
COMMENT ON COLUMN {{quote_ident .SchemaName}}.{{quote_ident .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';
|
||||||
@@ -1 +1 @@
|
|||||||
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS '{{.Comment}}';
|
COMMENT ON TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} IS '{{.Comment}}';
|
||||||
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||||
|
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
ADD CONSTRAINT {{.ConstraintName}}
|
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||||
FOREIGN KEY ({{.SourceColumns}})
|
FOREIGN KEY ({{.SourceColumns}})
|
||||||
REFERENCES {{.TargetSchema}}.{{.TargetTable}} ({{.TargetColumns}})
|
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
|
||||||
ON DELETE {{.OnDelete}}
|
ON DELETE {{.OnDelete}}
|
||||||
ON UPDATE {{.OnUpdate}}
|
ON UPDATE {{.OnUpdate}}
|
||||||
DEFERRABLE;
|
DEFERRABLE;
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
|
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||||
|
FOREIGN KEY ({{.SourceColumns}})
|
||||||
|
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
|
||||||
|
ON DELETE {{.OnDelete}}
|
||||||
|
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
|
||||||
|
DEFERRABLE{{end}};
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
|
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
|
||||||
ON {{.SchemaName}}.{{.TableName}} USING {{.IndexType}} ({{.Columns}});
|
ON {{quote_ident .SchemaName}}.{{quote_ident .TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||||
@@ -6,8 +6,8 @@ BEGIN
|
|||||||
AND table_name = '{{.TableName}}'
|
AND table_name = '{{.TableName}}'
|
||||||
AND constraint_name = '{{.ConstraintName}}'
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
) THEN
|
) THEN
|
||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
ADD CONSTRAINT {{.ConstraintName}} PRIMARY KEY ({{.Columns}});
|
ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||||
END IF;
|
END IF;
|
||||||
END;
|
END;
|
||||||
$$;
|
$$;
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
auto_pk_name text;
|
||||||
|
BEGIN
|
||||||
|
-- Drop auto-generated primary key if it exists
|
||||||
|
SELECT constraint_name INTO auto_pk_name
|
||||||
|
FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_type = 'PRIMARY KEY'
|
||||||
|
AND constraint_name IN ({{.AutoGenNames}});
|
||||||
|
|
||||||
|
IF auto_pk_name IS NOT NULL THEN
|
||||||
|
EXECUTE 'ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Add named primary key if it doesn't exist
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
CREATE SEQUENCE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}
|
||||||
|
INCREMENT {{.Increment}}
|
||||||
|
MINVALUE {{.MinValue}}
|
||||||
|
MAXVALUE {{.MaxValue}}
|
||||||
|
START {{.StartValue}}
|
||||||
|
CACHE {{.CacheSize}};
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
|
CREATE TABLE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .TableName}} (
|
||||||
{{- range $i, $col := .Columns}}
|
{{- range $i, $col := .Columns}}
|
||||||
{{- if $i}},{{end}}
|
{{- if $i}},{{end}}
|
||||||
{{$col.Name}} {{$col.Type}}
|
{{quote_ident $col.Name}} {{$col.Type}}
|
||||||
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
||||||
{{- if $col.NotNull}} NOT NULL{{end}}
|
{{- if $col.NotNull}} NOT NULL{{end}}
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
|||||||
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = '{{.SchemaName}}'
|
||||||
|
AND table_name = '{{.TableName}}'
|
||||||
|
AND constraint_name = '{{.ConstraintName}}'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -1 +1 @@
|
|||||||
ALTER TABLE {{.SchemaName}}.{{.TableName}} DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||||
@@ -1 +1 @@
|
|||||||
DROP INDEX IF EXISTS {{.SchemaName}}.{{.IndexName}} CASCADE;
|
DROP INDEX IF EXISTS {{quote_ident .SchemaName}}.{{quote_ident .IndexName}} CASCADE;
|
||||||
19
pkg/writers/pgsql/templates/set_sequence_value.tmpl
Normal file
19
pkg/writers/pgsql/templates/set_sequence_value.tmpl
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
m_cnt bigint;
|
||||||
|
BEGIN
|
||||||
|
IF EXISTS (
|
||||||
|
SELECT 1 FROM pg_class c
|
||||||
|
INNER JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||||
|
WHERE c.relname = '{{.SequenceName}}'
|
||||||
|
AND n.nspname = '{{.SchemaName}}'
|
||||||
|
AND c.relkind = 'S'
|
||||||
|
) THEN
|
||||||
|
SELECT COALESCE(MAX({{quote_ident .ColumnName}}), 0) + 1
|
||||||
|
FROM {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||||
|
INTO m_cnt;
|
||||||
|
|
||||||
|
PERFORM setval('{{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}'::regclass, m_cnt);
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ type Writer struct {
|
|||||||
options *writers.WriterOptions
|
options *writers.WriterOptions
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
executionReport *ExecutionReport
|
executionReport *ExecutionReport
|
||||||
|
executor *TemplateExecutor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecutionReport tracks the execution status of SQL statements
|
// ExecutionReport tracks the execution status of SQL statements
|
||||||
@@ -56,8 +58,10 @@ type ExecutionError struct {
|
|||||||
|
|
||||||
// NewWriter creates a new PostgreSQL SQL writer
|
// NewWriter creates a new PostgreSQL SQL writer
|
||||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||||
|
executor, _ := NewTemplateExecutor()
|
||||||
return &Writer{
|
return &Writer{
|
||||||
options: options,
|
options: options,
|
||||||
|
executor: executor,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,6 +171,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
statements = append(statements, stmts...)
|
statements = append(statements, stmts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Phase 3.5: Add missing columns (for existing tables)
|
||||||
|
addColStmts, err := w.GenerateAddColumnStatements(schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate add column statements: %w", err)
|
||||||
|
}
|
||||||
|
statements = append(statements, addColStmts...)
|
||||||
|
|
||||||
// Phase 4: Primary keys
|
// Phase 4: Primary keys
|
||||||
for _, table := range schema.Tables {
|
for _, table := range schema.Tables {
|
||||||
// First check for explicit PrimaryKeyConstraint
|
// First check for explicit PrimaryKeyConstraint
|
||||||
@@ -178,27 +189,50 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var pkColumns []string
|
||||||
|
var pkName string
|
||||||
|
|
||||||
if pkConstraint != nil {
|
if pkConstraint != nil {
|
||||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
pkColumns = pkConstraint.Columns
|
||||||
schema.SQLName(), table.SQLName(), pkConstraint.Name, strings.Join(pkConstraint.Columns, ", "))
|
pkName = pkConstraint.Name
|
||||||
statements = append(statements, stmt)
|
|
||||||
} else {
|
} else {
|
||||||
// No explicit constraint, check for columns with IsPrimaryKey = true
|
// No explicit constraint, check for columns with IsPrimaryKey = true
|
||||||
pkColumns := []string{}
|
pkCols := []string{}
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
if col.IsPrimaryKey {
|
if col.IsPrimaryKey {
|
||||||
pkColumns = append(pkColumns, col.SQLName())
|
pkCols = append(pkCols, col.SQLName())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(pkColumns) > 0 {
|
if len(pkCols) > 0 {
|
||||||
// Sort for consistent output
|
// Sort for consistent output
|
||||||
sort.Strings(pkColumns)
|
sort.Strings(pkCols)
|
||||||
pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
|
pkColumns = pkCols
|
||||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
pkName = fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
|
||||||
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
|
|
||||||
statements = append(statements, stmt)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(pkColumns) > 0 {
|
||||||
|
// Auto-generated primary key names to check for and drop
|
||||||
|
autoGenPKNames := []string{
|
||||||
|
fmt.Sprintf("%s_pkey", table.Name),
|
||||||
|
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use template to generate primary key statement
|
||||||
|
data := CreatePrimaryKeyWithAutoGenCheckData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ConstraintName: pkName,
|
||||||
|
AutoGenNames: formatStringList(autoGenPKNames),
|
||||||
|
Columns: strings.Join(pkColumns, ", "),
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
|
||||||
|
}
|
||||||
|
statements = append(statements, stmt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 5: Indexes
|
// Phase 5: Indexes
|
||||||
@@ -242,7 +276,53 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
||||||
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
uniqueStr, quoteIdentifier(index.Name), schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||||
|
statements = append(statements, stmt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 5.5: Unique constraints
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
for _, constraint := range table.Constraints {
|
||||||
|
if constraint.Type != models.UniqueConstraint {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use template to generate unique constraint statement
|
||||||
|
data := CreateUniqueConstraintData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ConstraintName: constraint.Name,
|
||||||
|
Columns: strings.Join(constraint.Columns, ", "),
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := w.executor.ExecuteCreateUniqueConstraint(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate unique constraint for %s.%s: %w", schema.Name, table.Name, err)
|
||||||
|
}
|
||||||
|
statements = append(statements, stmt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 5.7: Check constraints
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
for _, constraint := range table.Constraints {
|
||||||
|
if constraint.Type != models.CheckConstraint {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use template to generate check constraint statement
|
||||||
|
data := CreateCheckConstraintData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ConstraintName: constraint.Name,
|
||||||
|
Expression: constraint.Expression,
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := w.executor.ExecuteCreateCheckConstraint(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate check constraint for %s.%s: %w", schema.Name, table.Name, err)
|
||||||
|
}
|
||||||
statements = append(statements, stmt)
|
statements = append(statements, stmt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -269,12 +349,24 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
onUpdate = "NO ACTION"
|
onUpdate = "NO ACTION"
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
|
// Use template to generate foreign key statement
|
||||||
schema.SQLName(), table.SQLName(), constraint.Name,
|
data := CreateForeignKeyWithCheckData{
|
||||||
strings.Join(constraint.Columns, ", "),
|
SchemaName: schema.Name,
|
||||||
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
TableName: table.Name,
|
||||||
strings.Join(constraint.ReferencedColumns, ", "),
|
ConstraintName: constraint.Name,
|
||||||
onDelete, onUpdate)
|
SourceColumns: strings.Join(constraint.Columns, ", "),
|
||||||
|
TargetSchema: refSchema,
|
||||||
|
TargetTable: constraint.ReferencedTable,
|
||||||
|
TargetColumns: strings.Join(constraint.ReferencedColumns, ", "),
|
||||||
|
OnDelete: onDelete,
|
||||||
|
OnUpdate: onUpdate,
|
||||||
|
Deferrable: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||||
|
}
|
||||||
statements = append(statements, stmt)
|
statements = append(statements, stmt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -299,6 +391,67 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
|||||||
return statements, nil
|
return statements, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateAddColumnStatements generates ALTER TABLE ADD COLUMN statements for existing tables
|
||||||
|
// This is useful for schema evolution when new columns are added to existing tables
|
||||||
|
func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, error) {
|
||||||
|
statements := []string{}
|
||||||
|
|
||||||
|
statements = append(statements, fmt.Sprintf("-- Add missing columns for schema: %s", schema.Name))
|
||||||
|
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
// Sort columns by sequence or name for consistent output
|
||||||
|
columns := make([]*models.Column, 0, len(table.Columns))
|
||||||
|
for _, col := range table.Columns {
|
||||||
|
columns = append(columns, col)
|
||||||
|
}
|
||||||
|
sort.Slice(columns, func(i, j int) bool {
|
||||||
|
if columns[i].Sequence != columns[j].Sequence {
|
||||||
|
return columns[i].Sequence < columns[j].Sequence
|
||||||
|
}
|
||||||
|
return columns[i].Name < columns[j].Name
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, col := range columns {
|
||||||
|
colDef := w.generateColumnDefinition(col)
|
||||||
|
|
||||||
|
// Use template to generate add column statement
|
||||||
|
data := AddColumnWithCheckData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ColumnName: col.Name,
|
||||||
|
ColumnDefinition: colDef,
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := w.executor.ExecuteAddColumnWithCheck(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||||
|
}
|
||||||
|
statements = append(statements, stmt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return statements, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database
|
||||||
|
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
|
||||||
|
statements := []string{}
|
||||||
|
|
||||||
|
statements = append(statements, "-- Add missing columns to existing tables")
|
||||||
|
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
|
||||||
|
statements = append(statements, "-- Generated by RelSpec")
|
||||||
|
|
||||||
|
for _, schema := range db.Schemas {
|
||||||
|
schemaStatements, err := w.GenerateAddColumnStatements(schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate add column statements for schema %s: %w", schema.Name, err)
|
||||||
|
}
|
||||||
|
statements = append(statements, schemaStatements...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return statements, nil
|
||||||
|
}
|
||||||
|
|
||||||
// generateCreateTableStatement generates CREATE TABLE statement
|
// generateCreateTableStatement generates CREATE TABLE statement
|
||||||
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
|
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
|
||||||
statements := []string{}
|
statements := []string{}
|
||||||
@@ -321,7 +474,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
|
|||||||
columnDefs = append(columnDefs, " "+def)
|
columnDefs = append(columnDefs, " "+def)
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
|
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (\n%s\n)",
|
||||||
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
||||||
statements = append(statements, stmt)
|
statements = append(statements, stmt)
|
||||||
|
|
||||||
@@ -332,16 +485,28 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
|
|||||||
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||||
parts := []string{col.SQLName()}
|
parts := []string{col.SQLName()}
|
||||||
|
|
||||||
// Type with length/precision
|
// Type with length/precision - convert to valid PostgreSQL type
|
||||||
typeStr := col.Type
|
baseType := pgsql.ConvertSQLType(col.Type)
|
||||||
|
typeStr := baseType
|
||||||
|
|
||||||
|
// Only add size specifiers for types that support them
|
||||||
if col.Length > 0 && col.Precision == 0 {
|
if col.Length > 0 && col.Precision == 0 {
|
||||||
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length)
|
if supportsLength(baseType) {
|
||||||
} else if col.Precision > 0 {
|
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
||||||
if col.Scale > 0 {
|
} else if isTextTypeWithoutLength(baseType) {
|
||||||
typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale)
|
// Convert text with length to varchar
|
||||||
} else {
|
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
|
||||||
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision)
|
|
||||||
}
|
}
|
||||||
|
// For types that don't support length (integer, bigint, etc.), ignore the length
|
||||||
|
} else if col.Precision > 0 {
|
||||||
|
if supportsPrecision(baseType) {
|
||||||
|
if col.Scale > 0 {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
||||||
|
} else {
|
||||||
|
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// For types that don't support precision, ignore it
|
||||||
}
|
}
|
||||||
parts = append(parts, typeStr)
|
parts = append(parts, typeStr)
|
||||||
|
|
||||||
@@ -394,6 +559,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Phase 3.5: Add missing columns (priority 120)
|
||||||
|
if err := w.writeAddColumns(schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 4: Create primary keys (priority 160)
|
// Phase 4: Create primary keys (priority 160)
|
||||||
if err := w.writePrimaryKeys(schema); err != nil {
|
if err := w.writePrimaryKeys(schema); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -404,6 +574,16 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Phase 5.5: Create unique constraints (priority 185)
|
||||||
|
if err := w.writeUniqueConstraints(schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 5.7: Create check constraints (priority 190)
|
||||||
|
if err := w.writeCheckConstraints(schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Phase 6: Create foreign key constraints (priority 195)
|
// Phase 6: Create foreign key constraints (priority 195)
|
||||||
if err := w.writeForeignKeys(schema); err != nil {
|
if err := w.writeForeignKeys(schema); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -435,6 +615,44 @@ func (w *Writer) WriteTable(table *models.Table) error {
|
|||||||
return w.WriteSchema(schema)
|
return w.WriteSchema(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAddColumnStatements writes ALTER TABLE ADD COLUMN statements for a database
|
||||||
|
// This is used for schema evolution/migration when new columns are added
|
||||||
|
func (w *Writer) WriteAddColumnStatements(db *models.Database) error {
|
||||||
|
var writer io.Writer
|
||||||
|
var file *os.File
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Use existing writer if already set (for testing)
|
||||||
|
if w.writer != nil {
|
||||||
|
writer = w.writer
|
||||||
|
} else if w.options.OutputPath != "" {
|
||||||
|
// Determine output destination
|
||||||
|
file, err = os.Create(w.options.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create output file: %w", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
writer = file
|
||||||
|
} else {
|
||||||
|
writer = os.Stdout
|
||||||
|
}
|
||||||
|
|
||||||
|
w.writer = writer
|
||||||
|
|
||||||
|
// Generate statements
|
||||||
|
statements, err := w.GenerateAddColumnsForDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write each statement
|
||||||
|
for _, stmt := range statements {
|
||||||
|
fmt.Fprintf(w.writer, "%s;\n\n", stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writeCreateSchema generates CREATE SCHEMA statement
|
// writeCreateSchema generates CREATE SCHEMA statement
|
||||||
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
||||||
if schema.Name == "public" {
|
if schema.Name == "public" {
|
||||||
@@ -463,13 +681,23 @@ func (w *Writer) writeSequences(schema *models.Schema) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
||||||
fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n",
|
|
||||||
schema.SQLName(), seqName)
|
data := CreateSequenceData{
|
||||||
fmt.Fprintf(w.writer, " INCREMENT 1\n")
|
SchemaName: schema.Name,
|
||||||
fmt.Fprintf(w.writer, " MINVALUE 1\n")
|
SequenceName: seqName,
|
||||||
fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n")
|
Increment: 1,
|
||||||
fmt.Fprintf(w.writer, " START 1\n")
|
MinValue: 1,
|
||||||
fmt.Fprintf(w.writer, " CACHE 1;\n\n")
|
MaxValue: 9223372036854775807,
|
||||||
|
StartValue: 1,
|
||||||
|
CacheSize: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
sql, err := w.executor.ExecuteCreateSequence(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate create sequence for %s.%s: %w", schema.Name, seqName, err)
|
||||||
|
}
|
||||||
|
fmt.Fprint(w.writer, sql)
|
||||||
|
fmt.Fprint(w.writer, "\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -488,15 +716,8 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
|||||||
columnDefs := make([]string, 0, len(columns))
|
columnDefs := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type)
|
// Use generateColumnDefinition to properly handle type, length, precision, and defaults
|
||||||
|
colDef := " " + w.generateColumnDefinition(col)
|
||||||
// Add default value if present
|
|
||||||
if col.Default != nil && col.Default != "" {
|
|
||||||
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
|
||||||
defaultVal := fmt.Sprintf("%v", col.Default)
|
|
||||||
colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal))
|
|
||||||
}
|
|
||||||
|
|
||||||
columnDefs = append(columnDefs, colDef)
|
columnDefs = append(columnDefs, colDef)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -507,6 +728,36 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeAddColumns generates ALTER TABLE ADD COLUMN statements for missing columns
|
||||||
|
func (w *Writer) writeAddColumns(schema *models.Schema) error {
|
||||||
|
fmt.Fprintf(w.writer, "-- Add missing columns for schema: %s\n", schema.Name)
|
||||||
|
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
// Sort columns by sequence or name for consistent output
|
||||||
|
columns := getSortedColumns(table.Columns)
|
||||||
|
|
||||||
|
for _, col := range columns {
|
||||||
|
colDef := w.generateColumnDefinition(col)
|
||||||
|
|
||||||
|
data := AddColumnWithCheckData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ColumnName: col.Name,
|
||||||
|
ColumnDefinition: colDef,
|
||||||
|
}
|
||||||
|
|
||||||
|
sql, err := w.executor.ExecuteAddColumnWithCheck(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||||
|
}
|
||||||
|
fmt.Fprint(w.writer, sql)
|
||||||
|
fmt.Fprint(w.writer, "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
||||||
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
||||||
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
||||||
@@ -548,18 +799,26 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
// Auto-generated primary key names to check for and drop
|
||||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
autoGenPKNames := []string{
|
||||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
fmt.Sprintf("%s_pkey", table.Name),
|
||||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
|
||||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
}
|
||||||
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName)
|
|
||||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
data := CreatePrimaryKeyWithAutoGenCheckData{
|
||||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
SchemaName: schema.Name,
|
||||||
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n",
|
TableName: table.Name,
|
||||||
pkName, strings.Join(columnNames, ", "))
|
ConstraintName: pkName,
|
||||||
fmt.Fprintf(w.writer, " END IF;\n")
|
AutoGenNames: formatStringList(autoGenPKNames),
|
||||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
Columns: strings.Join(columnNames, ", "),
|
||||||
|
}
|
||||||
|
|
||||||
|
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
|
||||||
|
}
|
||||||
|
fmt.Fprint(w.writer, sql)
|
||||||
|
fmt.Fprint(w.writer, "\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -590,9 +849,10 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
|||||||
if indexName == "" {
|
if indexName == "" {
|
||||||
indexType := "idx"
|
indexType := "idx"
|
||||||
if index.Unique {
|
if index.Unique {
|
||||||
indexType = "uk"
|
indexType = "uidx"
|
||||||
}
|
}
|
||||||
indexName = fmt.Sprintf("%s_%s_%s", indexType, schema.SQLName(), table.SQLName())
|
columnSuffix := strings.Join(index.Columns, "_")
|
||||||
|
indexName = fmt.Sprintf("%s_%s_%s", indexType, table.SQLName(), strings.ToLower(columnSuffix))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build column list with operator class support for GIN indexes
|
// Build column list with operator class support for GIN indexes
|
||||||
@@ -641,6 +901,91 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeUniqueConstraints generates ALTER TABLE statements for unique constraints
|
||||||
|
func (w *Writer) writeUniqueConstraints(schema *models.Schema) error {
|
||||||
|
fmt.Fprintf(w.writer, "-- Unique constraints for schema: %s\n", schema.Name)
|
||||||
|
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
// Sort constraints by name for consistent output
|
||||||
|
constraintNames := make([]string, 0, len(table.Constraints))
|
||||||
|
for name, constraint := range table.Constraints {
|
||||||
|
if constraint.Type == models.UniqueConstraint {
|
||||||
|
constraintNames = append(constraintNames, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(constraintNames)
|
||||||
|
|
||||||
|
for _, name := range constraintNames {
|
||||||
|
constraint := table.Constraints[name]
|
||||||
|
|
||||||
|
// Build column list
|
||||||
|
columnExprs := make([]string, 0, len(constraint.Columns))
|
||||||
|
for _, colName := range constraint.Columns {
|
||||||
|
if col, ok := table.Columns[colName]; ok {
|
||||||
|
columnExprs = append(columnExprs, col.SQLName())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columnExprs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sql, err := w.executor.ExecuteCreateUniqueConstraint(CreateUniqueConstraintData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ConstraintName: constraint.Name,
|
||||||
|
Columns: strings.Join(columnExprs, ", "),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate unique constraint: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w.writer, "%s\n\n", sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCheckConstraints generates ALTER TABLE statements for check constraints
|
||||||
|
func (w *Writer) writeCheckConstraints(schema *models.Schema) error {
|
||||||
|
fmt.Fprintf(w.writer, "-- Check constraints for schema: %s\n", schema.Name)
|
||||||
|
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
// Sort constraints by name for consistent output
|
||||||
|
constraintNames := make([]string, 0, len(table.Constraints))
|
||||||
|
for name, constraint := range table.Constraints {
|
||||||
|
if constraint.Type == models.CheckConstraint {
|
||||||
|
constraintNames = append(constraintNames, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(constraintNames)
|
||||||
|
|
||||||
|
for _, name := range constraintNames {
|
||||||
|
constraint := table.Constraints[name]
|
||||||
|
|
||||||
|
// Skip if expression is empty
|
||||||
|
if constraint.Expression == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sql, err := w.executor.ExecuteCreateCheckConstraint(CreateCheckConstraintData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ConstraintName: constraint.Name,
|
||||||
|
Expression: constraint.Expression,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate check constraint: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w.writer, "%s\n\n", sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// writeForeignKeys generates ALTER TABLE statements for foreign keys
|
// writeForeignKeys generates ALTER TABLE statements for foreign keys
|
||||||
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
||||||
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
|
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
|
||||||
@@ -708,13 +1053,6 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
|||||||
onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
|
onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
||||||
fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName)
|
|
||||||
fmt.Fprintf(w.writer, "\n")
|
|
||||||
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
||||||
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
|
|
||||||
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
|
|
||||||
|
|
||||||
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
|
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
|
||||||
refSchema := fkConstraint.ReferencedSchema
|
refSchema := fkConstraint.ReferencedSchema
|
||||||
if refSchema == "" {
|
if refSchema == "" {
|
||||||
@@ -725,11 +1063,103 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
|||||||
refTable = rel.ToTable
|
refTable = rel.ToTable
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
|
// Use template executor to generate foreign key with existence check
|
||||||
refSchema, refTable, strings.Join(targetColumns, ", "))
|
data := CreateForeignKeyWithCheckData{
|
||||||
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
|
SchemaName: schema.Name,
|
||||||
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
|
TableName: table.Name,
|
||||||
fmt.Fprintf(w.writer, " DEFERRABLE;\n\n")
|
ConstraintName: fkName,
|
||||||
|
SourceColumns: strings.Join(sourceColumns, ", "),
|
||||||
|
TargetSchema: refSchema,
|
||||||
|
TargetTable: refTable,
|
||||||
|
TargetColumns: strings.Join(targetColumns, ", "),
|
||||||
|
OnDelete: onDelete,
|
||||||
|
OnUpdate: onUpdate,
|
||||||
|
Deferrable: true,
|
||||||
|
}
|
||||||
|
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||||
|
}
|
||||||
|
fmt.Fprint(w.writer, sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also process any foreign key constraints that don't have a relationship
|
||||||
|
processedConstraints := make(map[string]bool)
|
||||||
|
for _, rel := range table.Relationships {
|
||||||
|
fkName := rel.ForeignKey
|
||||||
|
if fkName == "" {
|
||||||
|
fkName = rel.Name
|
||||||
|
}
|
||||||
|
if fkName != "" {
|
||||||
|
processedConstraints[fkName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find unprocessed foreign key constraints
|
||||||
|
constraintNames := make([]string, 0)
|
||||||
|
for name, constraint := range table.Constraints {
|
||||||
|
if constraint.Type == models.ForeignKeyConstraint && !processedConstraints[name] {
|
||||||
|
constraintNames = append(constraintNames, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(constraintNames)
|
||||||
|
|
||||||
|
for _, name := range constraintNames {
|
||||||
|
constraint := table.Constraints[name]
|
||||||
|
|
||||||
|
// Build column lists
|
||||||
|
sourceColumns := make([]string, 0, len(constraint.Columns))
|
||||||
|
for _, colName := range constraint.Columns {
|
||||||
|
if col, ok := table.Columns[colName]; ok {
|
||||||
|
sourceColumns = append(sourceColumns, col.SQLName())
|
||||||
|
} else {
|
||||||
|
sourceColumns = append(sourceColumns, colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
targetColumns := make([]string, 0, len(constraint.ReferencedColumns))
|
||||||
|
for _, colName := range constraint.ReferencedColumns {
|
||||||
|
targetColumns = append(targetColumns, strings.ToLower(colName))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
onDelete := "NO ACTION"
|
||||||
|
if constraint.OnDelete != "" {
|
||||||
|
onDelete = strings.ToUpper(constraint.OnDelete)
|
||||||
|
}
|
||||||
|
|
||||||
|
onUpdate := "NO ACTION"
|
||||||
|
if constraint.OnUpdate != "" {
|
||||||
|
onUpdate = strings.ToUpper(constraint.OnUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
refSchema := constraint.ReferencedSchema
|
||||||
|
if refSchema == "" {
|
||||||
|
refSchema = schema.Name
|
||||||
|
}
|
||||||
|
refTable := constraint.ReferencedTable
|
||||||
|
|
||||||
|
// Use template executor to generate foreign key with existence check
|
||||||
|
data := CreateForeignKeyWithCheckData{
|
||||||
|
SchemaName: schema.Name,
|
||||||
|
TableName: table.Name,
|
||||||
|
ConstraintName: constraint.Name,
|
||||||
|
SourceColumns: strings.Join(sourceColumns, ", "),
|
||||||
|
TargetSchema: refSchema,
|
||||||
|
TargetTable: refTable,
|
||||||
|
TargetColumns: strings.Join(targetColumns, ", "),
|
||||||
|
OnDelete: onDelete,
|
||||||
|
OnUpdate: onUpdate,
|
||||||
|
Deferrable: false,
|
||||||
|
}
|
||||||
|
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||||
|
}
|
||||||
|
fmt.Fprint(w.writer, sql)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -748,26 +1178,19 @@ func (w *Writer) writeSetSequenceValues(schema *models.Schema) error {
|
|||||||
|
|
||||||
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
||||||
|
|
||||||
fmt.Fprintf(w.writer, "DO $$\n")
|
// Use template executor to generate set sequence value statement
|
||||||
fmt.Fprintf(w.writer, "DECLARE\n")
|
data := SetSequenceValueData{
|
||||||
fmt.Fprintf(w.writer, " m_cnt bigint;\n")
|
SchemaName: schema.Name,
|
||||||
fmt.Fprintf(w.writer, "BEGIN\n")
|
TableName: table.Name,
|
||||||
fmt.Fprintf(w.writer, " IF EXISTS (\n")
|
SequenceName: seqName,
|
||||||
fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n")
|
ColumnName: pk.Name,
|
||||||
fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n")
|
}
|
||||||
fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName)
|
sql, err := w.executor.ExecuteSetSequenceValue(data)
|
||||||
fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name)
|
if err != nil {
|
||||||
fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n")
|
return fmt.Errorf("failed to generate set sequence value for %s.%s: %w", schema.Name, table.Name, err)
|
||||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
}
|
||||||
fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName())
|
fmt.Fprint(w.writer, sql)
|
||||||
fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName())
|
fmt.Fprint(w.writer, "\n")
|
||||||
fmt.Fprintf(w.writer, " INTO m_cnt;\n")
|
|
||||||
fmt.Fprintf(w.writer, " \n")
|
|
||||||
fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n",
|
|
||||||
schema.SQLName(), seqName)
|
|
||||||
fmt.Fprintf(w.writer, " END IF;\n")
|
|
||||||
fmt.Fprintf(w.writer, "END;\n")
|
|
||||||
fmt.Fprintf(w.writer, "$$;\n\n")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -841,6 +1264,44 @@ func isTextType(colType string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// supportsLength checks if a PostgreSQL type supports length specification
|
||||||
|
func supportsLength(colType string) bool {
|
||||||
|
lengthTypes := []string{"varchar", "character varying", "char", "character", "bit", "bit varying", "varbit"}
|
||||||
|
lowerType := strings.ToLower(colType)
|
||||||
|
for _, t := range lengthTypes {
|
||||||
|
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// supportsPrecision checks if a PostgreSQL type supports precision/scale specification
|
||||||
|
func supportsPrecision(colType string) bool {
|
||||||
|
precisionTypes := []string{"numeric", "decimal", "time", "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone", "time with time zone", "time without time zone", "interval"}
|
||||||
|
lowerType := strings.ToLower(colType)
|
||||||
|
for _, t := range precisionTypes {
|
||||||
|
if lowerType == t || strings.HasPrefix(lowerType, t+"(") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified)
|
||||||
|
func isTextTypeWithoutLength(colType string) bool {
|
||||||
|
return strings.EqualFold(colType, "text")
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
|
||||||
|
func formatStringList(items []string) string {
|
||||||
|
quoted := make([]string, len(items))
|
||||||
|
for i, item := range items {
|
||||||
|
quoted[i] = fmt.Sprintf("'%s'", escapeQuote(item))
|
||||||
|
}
|
||||||
|
return strings.Join(quoted, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
// extractOperatorClass extracts operator class from index comment/note
|
// extractOperatorClass extracts operator class from index comment/note
|
||||||
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
|
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
|
||||||
func extractOperatorClass(comment string) string {
|
func extractOperatorClass(comment string) string {
|
||||||
@@ -949,7 +1410,8 @@ func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) erro
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "Executing statement %d/%d...\n", i+1, len(statements))
|
stmtType := detectStatementType(stmtTrimmed)
|
||||||
|
fmt.Fprintf(os.Stderr, "Executing statement %d/%d [%s]...\n", i+1, len(statements), stmtType)
|
||||||
|
|
||||||
_, execErr := conn.Exec(ctx, stmt)
|
_, execErr := conn.Exec(ctx, stmt)
|
||||||
if execErr != nil {
|
if execErr != nil {
|
||||||
@@ -1083,3 +1545,94 @@ func truncateStatement(stmt string) string {
|
|||||||
func getCurrentTimestamp() string {
|
func getCurrentTimestamp() string {
|
||||||
return time.Now().Format("2006-01-02 15:04:05")
|
return time.Now().Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectStatementType detects the type of SQL statement for logging
|
||||||
|
func detectStatementType(stmt string) string {
|
||||||
|
upperStmt := strings.ToUpper(stmt)
|
||||||
|
|
||||||
|
// Check for DO blocks (used for conditional DDL)
|
||||||
|
if strings.HasPrefix(upperStmt, "DO $$") || strings.HasPrefix(upperStmt, "DO $") {
|
||||||
|
// Look inside the DO block for the actual operation
|
||||||
|
if strings.Contains(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD CONSTRAINT") {
|
||||||
|
if strings.Contains(upperStmt, "UNIQUE") {
|
||||||
|
return "ADD UNIQUE CONSTRAINT"
|
||||||
|
} else if strings.Contains(upperStmt, "FOREIGN KEY") {
|
||||||
|
return "ADD FOREIGN KEY"
|
||||||
|
} else if strings.Contains(upperStmt, "PRIMARY KEY") {
|
||||||
|
return "ADD PRIMARY KEY"
|
||||||
|
} else if strings.Contains(upperStmt, "CHECK") {
|
||||||
|
return "ADD CHECK CONSTRAINT"
|
||||||
|
}
|
||||||
|
return "ADD CONSTRAINT"
|
||||||
|
}
|
||||||
|
if strings.Contains(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD COLUMN") {
|
||||||
|
return "ADD COLUMN"
|
||||||
|
}
|
||||||
|
if strings.Contains(upperStmt, "DROP CONSTRAINT") {
|
||||||
|
return "DROP CONSTRAINT"
|
||||||
|
}
|
||||||
|
return "DO BLOCK"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct DDL statements
|
||||||
|
if strings.HasPrefix(upperStmt, "CREATE SCHEMA") {
|
||||||
|
return "CREATE SCHEMA"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "CREATE SEQUENCE") {
|
||||||
|
return "CREATE SEQUENCE"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "CREATE TABLE") {
|
||||||
|
return "CREATE TABLE"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "CREATE INDEX") {
|
||||||
|
return "CREATE INDEX"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "CREATE UNIQUE INDEX") {
|
||||||
|
return "CREATE UNIQUE INDEX"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "ALTER TABLE") {
|
||||||
|
if strings.Contains(upperStmt, "ADD CONSTRAINT") {
|
||||||
|
if strings.Contains(upperStmt, "FOREIGN KEY") {
|
||||||
|
return "ADD FOREIGN KEY"
|
||||||
|
} else if strings.Contains(upperStmt, "PRIMARY KEY") {
|
||||||
|
return "ADD PRIMARY KEY"
|
||||||
|
} else if strings.Contains(upperStmt, "UNIQUE") {
|
||||||
|
return "ADD UNIQUE CONSTRAINT"
|
||||||
|
} else if strings.Contains(upperStmt, "CHECK") {
|
||||||
|
return "ADD CHECK CONSTRAINT"
|
||||||
|
}
|
||||||
|
return "ADD CONSTRAINT"
|
||||||
|
}
|
||||||
|
if strings.Contains(upperStmt, "ADD COLUMN") {
|
||||||
|
return "ADD COLUMN"
|
||||||
|
}
|
||||||
|
if strings.Contains(upperStmt, "DROP CONSTRAINT") {
|
||||||
|
return "DROP CONSTRAINT"
|
||||||
|
}
|
||||||
|
if strings.Contains(upperStmt, "ALTER COLUMN") {
|
||||||
|
return "ALTER COLUMN"
|
||||||
|
}
|
||||||
|
return "ALTER TABLE"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "COMMENT ON TABLE") {
|
||||||
|
return "COMMENT ON TABLE"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "COMMENT ON COLUMN") {
|
||||||
|
return "COMMENT ON COLUMN"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "DROP TABLE") {
|
||||||
|
return "DROP TABLE"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(upperStmt, "DROP INDEX") {
|
||||||
|
return "DROP INDEX"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default
|
||||||
|
return "SQL"
|
||||||
|
}
|
||||||
|
|
||||||
|
// quoteIdentifier wraps an identifier in double quotes if necessary
|
||||||
|
// This is needed for identifiers that start with numbers or contain special characters
|
||||||
|
func quoteIdentifier(s string) string {
|
||||||
|
return quoteIdent(s)
|
||||||
|
}
|
||||||
|
|||||||
@@ -45,11 +45,11 @@ func TestWriteDatabase(t *testing.T) {
|
|||||||
|
|
||||||
// Add unique index
|
// Add unique index
|
||||||
uniqueEmailIndex := &models.Index{
|
uniqueEmailIndex := &models.Index{
|
||||||
Name: "uk_users_email",
|
Name: "uidx_users_email",
|
||||||
Unique: true,
|
Unique: true,
|
||||||
Columns: []string{"email"},
|
Columns: []string{"email"},
|
||||||
}
|
}
|
||||||
table.Indexes["uk_users_email"] = uniqueEmailIndex
|
table.Indexes["uidx_users_email"] = uniqueEmailIndex
|
||||||
|
|
||||||
schema.Tables = append(schema.Tables, table)
|
schema.Tables = append(schema.Tables, table)
|
||||||
db.Schemas = append(db.Schemas, schema)
|
db.Schemas = append(db.Schemas, schema)
|
||||||
@@ -164,6 +164,296 @@ func TestWriteForeignKeys(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteUniqueConstraints(t *testing.T) {
|
||||||
|
// Create a test database with unique constraints
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create table with unique constraints
|
||||||
|
table := models.InitTable("users", "public")
|
||||||
|
|
||||||
|
// Add columns
|
||||||
|
emailCol := models.InitColumn("email", "users", "public")
|
||||||
|
emailCol.Type = "varchar(255)"
|
||||||
|
emailCol.NotNull = true
|
||||||
|
table.Columns["email"] = emailCol
|
||||||
|
|
||||||
|
guidCol := models.InitColumn("guid", "users", "public")
|
||||||
|
guidCol.Type = "uuid"
|
||||||
|
guidCol.NotNull = true
|
||||||
|
table.Columns["guid"] = guidCol
|
||||||
|
|
||||||
|
// Add unique constraints
|
||||||
|
emailConstraint := &models.Constraint{
|
||||||
|
Name: "uq_email",
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "users",
|
||||||
|
Columns: []string{"email"},
|
||||||
|
}
|
||||||
|
table.Constraints["uq_email"] = emailConstraint
|
||||||
|
|
||||||
|
guidConstraint := &models.Constraint{
|
||||||
|
Name: "uq_guid",
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "users",
|
||||||
|
Columns: []string{"guid"},
|
||||||
|
}
|
||||||
|
table.Constraints["uq_guid"] = guidConstraint
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write the database
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Print output for debugging
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify unique constraints are present
|
||||||
|
if !strings.Contains(output, "-- Unique constraints for schema: public") {
|
||||||
|
t.Errorf("Output missing unique constraints header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT uq_email UNIQUE (email)") {
|
||||||
|
t.Errorf("Output missing uq_email unique constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT uq_guid UNIQUE (guid)") {
|
||||||
|
t.Errorf("Output missing uq_guid unique constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteCheckConstraints(t *testing.T) {
|
||||||
|
// Create a test database with check constraints
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create table with check constraints
|
||||||
|
table := models.InitTable("products", "public")
|
||||||
|
|
||||||
|
// Add columns
|
||||||
|
priceCol := models.InitColumn("price", "products", "public")
|
||||||
|
priceCol.Type = "numeric(10,2)"
|
||||||
|
table.Columns["price"] = priceCol
|
||||||
|
|
||||||
|
statusCol := models.InitColumn("status", "products", "public")
|
||||||
|
statusCol.Type = "varchar(20)"
|
||||||
|
table.Columns["status"] = statusCol
|
||||||
|
|
||||||
|
quantityCol := models.InitColumn("quantity", "products", "public")
|
||||||
|
quantityCol.Type = "integer"
|
||||||
|
table.Columns["quantity"] = quantityCol
|
||||||
|
|
||||||
|
// Add check constraints
|
||||||
|
priceConstraint := &models.Constraint{
|
||||||
|
Name: "ck_price_positive",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "products",
|
||||||
|
Expression: "price >= 0",
|
||||||
|
}
|
||||||
|
table.Constraints["ck_price_positive"] = priceConstraint
|
||||||
|
|
||||||
|
statusConstraint := &models.Constraint{
|
||||||
|
Name: "ck_status_valid",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "products",
|
||||||
|
Expression: "status IN ('active', 'inactive', 'discontinued')",
|
||||||
|
}
|
||||||
|
table.Constraints["ck_status_valid"] = statusConstraint
|
||||||
|
|
||||||
|
quantityConstraint := &models.Constraint{
|
||||||
|
Name: "ck_quantity_nonnegative",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "products",
|
||||||
|
Expression: "quantity >= 0",
|
||||||
|
}
|
||||||
|
table.Constraints["ck_quantity_nonnegative"] = quantityConstraint
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write the database
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Print output for debugging
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify check constraints are present
|
||||||
|
if !strings.Contains(output, "-- Check constraints for schema: public") {
|
||||||
|
t.Errorf("Output missing check constraints header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
|
||||||
|
t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
|
||||||
|
t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
|
||||||
|
t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteAllConstraintTypes(t *testing.T) {
|
||||||
|
// Create a comprehensive test with all constraint types
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create orders table
|
||||||
|
ordersTable := models.InitTable("orders", "public")
|
||||||
|
|
||||||
|
// Add columns
|
||||||
|
idCol := models.InitColumn("id", "orders", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.IsPrimaryKey = true
|
||||||
|
ordersTable.Columns["id"] = idCol
|
||||||
|
|
||||||
|
userIDCol := models.InitColumn("user_id", "orders", "public")
|
||||||
|
userIDCol.Type = "integer"
|
||||||
|
userIDCol.NotNull = true
|
||||||
|
ordersTable.Columns["user_id"] = userIDCol
|
||||||
|
|
||||||
|
orderNumberCol := models.InitColumn("order_number", "orders", "public")
|
||||||
|
orderNumberCol.Type = "varchar(50)"
|
||||||
|
orderNumberCol.NotNull = true
|
||||||
|
ordersTable.Columns["order_number"] = orderNumberCol
|
||||||
|
|
||||||
|
totalCol := models.InitColumn("total", "orders", "public")
|
||||||
|
totalCol.Type = "numeric(10,2)"
|
||||||
|
ordersTable.Columns["total"] = totalCol
|
||||||
|
|
||||||
|
statusCol := models.InitColumn("status", "orders", "public")
|
||||||
|
statusCol.Type = "varchar(20)"
|
||||||
|
ordersTable.Columns["status"] = statusCol
|
||||||
|
|
||||||
|
// Add primary key constraint
|
||||||
|
pkConstraint := &models.Constraint{
|
||||||
|
Name: "pk_orders",
|
||||||
|
Type: models.PrimaryKeyConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Columns: []string{"id"},
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["pk_orders"] = pkConstraint
|
||||||
|
|
||||||
|
// Add unique constraint
|
||||||
|
uniqueConstraint := &models.Constraint{
|
||||||
|
Name: "uq_order_number",
|
||||||
|
Type: models.UniqueConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Columns: []string{"order_number"},
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["uq_order_number"] = uniqueConstraint
|
||||||
|
|
||||||
|
// Add check constraint
|
||||||
|
checkConstraint := &models.Constraint{
|
||||||
|
Name: "ck_total_positive",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Expression: "total > 0",
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["ck_total_positive"] = checkConstraint
|
||||||
|
|
||||||
|
statusCheckConstraint := &models.Constraint{
|
||||||
|
Name: "ck_status_valid",
|
||||||
|
Type: models.CheckConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Expression: "status IN ('pending', 'completed', 'cancelled')",
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint
|
||||||
|
|
||||||
|
// Add foreign key constraint (referencing a users table)
|
||||||
|
fkConstraint := &models.Constraint{
|
||||||
|
Name: "fk_orders_user",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Schema: "public",
|
||||||
|
Table: "orders",
|
||||||
|
Columns: []string{"user_id"},
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
ReferencedTable: "users",
|
||||||
|
ReferencedColumns: []string{"id"},
|
||||||
|
OnDelete: "CASCADE",
|
||||||
|
OnUpdate: "CASCADE",
|
||||||
|
}
|
||||||
|
ordersTable.Constraints["fk_orders_user"] = fkConstraint
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, ordersTable)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write the database
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Print output for debugging
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify all constraint types are present
|
||||||
|
expectedConstraints := map[string]string{
|
||||||
|
"Primary Key": "PRIMARY KEY",
|
||||||
|
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
|
||||||
|
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
|
||||||
|
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
|
||||||
|
"Foreign Key": "FOREIGN KEY",
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, expected := range expectedConstraints {
|
||||||
|
if !strings.Contains(output, expected) {
|
||||||
|
t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify section headers
|
||||||
|
sections := []string{
|
||||||
|
"-- Primary keys for schema: public",
|
||||||
|
"-- Unique constraints for schema: public",
|
||||||
|
"-- Check constraints for schema: public",
|
||||||
|
"-- Foreign keys for schema: public",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, section := range sections {
|
||||||
|
if !strings.Contains(output, section) {
|
||||||
|
t.Errorf("Output missing section header: %s", section)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteTable(t *testing.T) {
|
func TestWriteTable(t *testing.T) {
|
||||||
// Create a single table
|
// Create a single table
|
||||||
table := models.InitTable("products", "public")
|
table := models.InitTable("products", "public")
|
||||||
@@ -241,3 +531,327 @@ func TestIsIntegerType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTypeConversion(t *testing.T) {
|
||||||
|
// Test that invalid Go types are converted to valid PostgreSQL types
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create a test table with Go types instead of SQL types
|
||||||
|
table := models.InitTable("test_types", "public")
|
||||||
|
|
||||||
|
// Add columns with Go types (invalid for PostgreSQL)
|
||||||
|
stringCol := models.InitColumn("name", "test_types", "public")
|
||||||
|
stringCol.Type = "string" // Should be converted to "text"
|
||||||
|
table.Columns["name"] = stringCol
|
||||||
|
|
||||||
|
int64Col := models.InitColumn("big_id", "test_types", "public")
|
||||||
|
int64Col.Type = "int64" // Should be converted to "bigint"
|
||||||
|
table.Columns["big_id"] = int64Col
|
||||||
|
|
||||||
|
int16Col := models.InitColumn("small_id", "test_types", "public")
|
||||||
|
int16Col.Type = "int16" // Should be converted to "smallint"
|
||||||
|
table.Columns["small_id"] = int16Col
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write the database
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Print output for debugging
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify that Go types were converted to PostgreSQL types
|
||||||
|
if strings.Contains(output, "string") {
|
||||||
|
t.Errorf("Output contains 'string' type - should be converted to 'text'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "int64") {
|
||||||
|
t.Errorf("Output contains 'int64' type - should be converted to 'bigint'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "int16") {
|
||||||
|
t.Errorf("Output contains 'int16' type - should be converted to 'smallint'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify correct PostgreSQL types are present
|
||||||
|
if !strings.Contains(output, "text") {
|
||||||
|
t.Errorf("Output missing 'text' type (converted from 'string')\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "bigint") {
|
||||||
|
t.Errorf("Output missing 'bigint' type (converted from 'int64')\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "smallint") {
|
||||||
|
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrimaryKeyExistenceCheck(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
table := models.InitTable("products", "public")
|
||||||
|
|
||||||
|
idCol := models.InitColumn("id", "products", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.IsPrimaryKey = true
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
nameCol := models.InitColumn("name", "products", "public")
|
||||||
|
nameCol.Type = "text"
|
||||||
|
table.Columns["name"] = nameCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify our naming convention is used
|
||||||
|
if !strings.Contains(output, "pk_public_products") {
|
||||||
|
t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it drops auto-generated primary keys
|
||||||
|
if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") {
|
||||||
|
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it checks for our specific named constraint before adding it
|
||||||
|
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
|
||||||
|
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestColumnSizeSpecifiers(t *testing.T) {
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
table := models.InitTable("test_sizes", "public")
|
||||||
|
|
||||||
|
// Integer with invalid size specifier - should ignore size
|
||||||
|
integerCol := models.InitColumn("int_col", "test_sizes", "public")
|
||||||
|
integerCol.Type = "integer"
|
||||||
|
integerCol.Length = 32
|
||||||
|
table.Columns["int_col"] = integerCol
|
||||||
|
|
||||||
|
// Bigint with invalid size specifier - should ignore size
|
||||||
|
bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
|
||||||
|
bigintCol.Type = "bigint"
|
||||||
|
bigintCol.Length = 64
|
||||||
|
table.Columns["bigint_col"] = bigintCol
|
||||||
|
|
||||||
|
// Smallint with invalid size specifier - should ignore size
|
||||||
|
smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
|
||||||
|
smallintCol.Type = "smallint"
|
||||||
|
smallintCol.Length = 16
|
||||||
|
table.Columns["smallint_col"] = smallintCol
|
||||||
|
|
||||||
|
// Text with length - should convert to varchar
|
||||||
|
textCol := models.InitColumn("text_col", "test_sizes", "public")
|
||||||
|
textCol.Type = "text"
|
||||||
|
textCol.Length = 100
|
||||||
|
table.Columns["text_col"] = textCol
|
||||||
|
|
||||||
|
// Varchar with length - should keep varchar with length
|
||||||
|
varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
|
||||||
|
varcharCol.Type = "varchar"
|
||||||
|
varcharCol.Length = 50
|
||||||
|
table.Columns["varchar_col"] = varcharCol
|
||||||
|
|
||||||
|
// Decimal with precision and scale - should keep them
|
||||||
|
decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
|
||||||
|
decimalCol.Type = "decimal"
|
||||||
|
decimalCol.Precision = 19
|
||||||
|
decimalCol.Scale = 4
|
||||||
|
table.Columns["decimal_col"] = decimalCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated SQL:\n%s", output)
|
||||||
|
|
||||||
|
// Verify invalid size specifiers are NOT present
|
||||||
|
invalidPatterns := []string{
|
||||||
|
"integer(32)",
|
||||||
|
"bigint(64)",
|
||||||
|
"smallint(16)",
|
||||||
|
"text(100)",
|
||||||
|
}
|
||||||
|
for _, pattern := range invalidPatterns {
|
||||||
|
if strings.Contains(output, pattern) {
|
||||||
|
t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify valid patterns ARE present
|
||||||
|
validPatterns := []string{
|
||||||
|
"integer", // without size
|
||||||
|
"bigint", // without size
|
||||||
|
"smallint", // without size
|
||||||
|
"varchar(100)", // text converted to varchar with length
|
||||||
|
"varchar(50)", // varchar with length
|
||||||
|
"decimal(19,4)", // decimal with precision and scale
|
||||||
|
}
|
||||||
|
for _, pattern := range validPatterns {
|
||||||
|
if !strings.Contains(output, pattern) {
|
||||||
|
t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateAddColumnStatements(t *testing.T) {
|
||||||
|
// Create a test database with tables that have new columns
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create a table with columns
|
||||||
|
table := models.InitTable("users", "public")
|
||||||
|
|
||||||
|
// Existing column
|
||||||
|
idCol := models.InitColumn("id", "users", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
idCol.NotNull = true
|
||||||
|
idCol.Sequence = 1
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
// New column to be added
|
||||||
|
emailCol := models.InitColumn("email", "users", "public")
|
||||||
|
emailCol.Type = "varchar"
|
||||||
|
emailCol.Length = 255
|
||||||
|
emailCol.NotNull = true
|
||||||
|
emailCol.Sequence = 2
|
||||||
|
table.Columns["email"] = emailCol
|
||||||
|
|
||||||
|
// New column with default
|
||||||
|
statusCol := models.InitColumn("status", "users", "public")
|
||||||
|
statusCol.Type = "text"
|
||||||
|
statusCol.Default = "active"
|
||||||
|
statusCol.Sequence = 3
|
||||||
|
table.Columns["status"] = statusCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
|
||||||
|
// Generate ADD COLUMN statements
|
||||||
|
statements, err := writer.GenerateAddColumnsForDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join all statements to verify content
|
||||||
|
output := strings.Join(statements, "\n")
|
||||||
|
t.Logf("Generated ADD COLUMN statements:\n%s", output)
|
||||||
|
|
||||||
|
// Verify expected elements
|
||||||
|
expectedStrings := []string{
|
||||||
|
"ALTER TABLE public.users ADD COLUMN id integer NOT NULL",
|
||||||
|
"ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL",
|
||||||
|
"ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'",
|
||||||
|
"information_schema.columns",
|
||||||
|
"table_schema = 'public'",
|
||||||
|
"table_name = 'users'",
|
||||||
|
"column_name = 'id'",
|
||||||
|
"column_name = 'email'",
|
||||||
|
"column_name = 'status'",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectedStrings {
|
||||||
|
if !strings.Contains(output, expected) {
|
||||||
|
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify DO blocks are present for conditional adds
|
||||||
|
doBlockCount := strings.Count(output, "DO $$")
|
||||||
|
if doBlockCount < 3 {
|
||||||
|
t.Errorf("Expected at least 3 DO blocks (one per column), got %d", doBlockCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify IF NOT EXISTS logic
|
||||||
|
ifNotExistsCount := strings.Count(output, "IF NOT EXISTS")
|
||||||
|
if ifNotExistsCount < 3 {
|
||||||
|
t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", ifNotExistsCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteAddColumnStatements(t *testing.T) {
|
||||||
|
// Create a test database
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("public")
|
||||||
|
|
||||||
|
// Create a table with a new column to be added
|
||||||
|
table := models.InitTable("products", "public")
|
||||||
|
|
||||||
|
idCol := models.InitColumn("id", "products", "public")
|
||||||
|
idCol.Type = "integer"
|
||||||
|
table.Columns["id"] = idCol
|
||||||
|
|
||||||
|
// New column with various properties
|
||||||
|
descCol := models.InitColumn("description", "products", "public")
|
||||||
|
descCol.Type = "text"
|
||||||
|
descCol.NotNull = false
|
||||||
|
table.Columns["description"] = descCol
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer with output to buffer
|
||||||
|
var buf bytes.Buffer
|
||||||
|
options := &writers.WriterOptions{}
|
||||||
|
writer := NewWriter(options)
|
||||||
|
writer.writer = &buf
|
||||||
|
|
||||||
|
// Write ADD COLUMN statements
|
||||||
|
err := writer.WriteAddColumnStatements(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteAddColumnStatements failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
t.Logf("Generated output:\n%s", output)
|
||||||
|
|
||||||
|
// Verify output contains expected elements
|
||||||
|
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN id integer") {
|
||||||
|
t.Errorf("Output missing ADD COLUMN for id\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN description text") {
|
||||||
|
t.Errorf("Output missing ADD COLUMN for description\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "DO $$") {
|
||||||
|
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ The SQL Executor Writer (`sqlexec`) executes SQL scripts from `models.Script` ob
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Ordered Execution**: Scripts execute in Priority→Sequence order
|
- **Ordered Execution**: Scripts execute in Priority→Sequence→Name order
|
||||||
- **PostgreSQL Support**: Uses `pgx/v5` driver for robust PostgreSQL connectivity
|
- **PostgreSQL Support**: Uses `pgx/v5` driver for robust PostgreSQL connectivity
|
||||||
- **Stop on Error**: Execution halts immediately on first error (default behavior)
|
- **Stop on Error**: Execution halts immediately on first error (default behavior)
|
||||||
- **Progress Reporting**: Prints execution status to stdout
|
- **Progress Reporting**: Prints execution status to stdout
|
||||||
@@ -103,19 +103,40 @@ Scripts are sorted and executed based on:
|
|||||||
|
|
||||||
1. **Priority** (ascending): Lower priority values execute first
|
1. **Priority** (ascending): Lower priority values execute first
|
||||||
2. **Sequence** (ascending): Within same priority, lower sequence values execute first
|
2. **Sequence** (ascending): Within same priority, lower sequence values execute first
|
||||||
|
3. **Name** (ascending): Within same priority and sequence, alphabetical order by name
|
||||||
|
|
||||||
### Example Execution Order
|
### Example Execution Order
|
||||||
|
|
||||||
Given these scripts:
|
Given these scripts:
|
||||||
```
|
```
|
||||||
Script A: Priority=2, Sequence=1
|
Script A: Priority=2, Sequence=1, Name="zebra"
|
||||||
Script B: Priority=1, Sequence=3
|
Script B: Priority=1, Sequence=3, Name="script"
|
||||||
Script C: Priority=1, Sequence=1
|
Script C: Priority=1, Sequence=1, Name="apple"
|
||||||
Script D: Priority=1, Sequence=2
|
Script D: Priority=1, Sequence=1, Name="beta"
|
||||||
Script E: Priority=3, Sequence=1
|
Script E: Priority=3, Sequence=1, Name="script"
|
||||||
```
|
```
|
||||||
|
|
||||||
Execution order: **C → D → B → A → E**
|
Execution order: **C (apple) → D (beta) → B → A → E**
|
||||||
|
|
||||||
|
### Directory-based Sorting Example
|
||||||
|
|
||||||
|
Given these files:
|
||||||
|
```
|
||||||
|
1_001_create_schema.sql
|
||||||
|
1_001_create_users.sql ← Alphabetically before "drop_tables"
|
||||||
|
1_001_drop_tables.sql
|
||||||
|
1_002_add_indexes.sql
|
||||||
|
2_001_constraints.sql
|
||||||
|
```
|
||||||
|
|
||||||
|
Execution order (note alphabetical sorting at same priority/sequence):
|
||||||
|
```
|
||||||
|
1_001_create_schema.sql
|
||||||
|
1_001_create_users.sql
|
||||||
|
1_001_drop_tables.sql
|
||||||
|
1_002_add_indexes.sql
|
||||||
|
2_001_constraints.sql
|
||||||
|
```
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,11 @@ func NewWriter(options *writers.WriterOptions) *Writer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Options returns the writer options (useful for reading execution results)
|
||||||
|
func (w *Writer) Options() *writers.WriterOptions {
|
||||||
|
return w.options
|
||||||
|
}
|
||||||
|
|
||||||
// WriteDatabase executes all scripts from all schemas in the database
|
// WriteDatabase executes all scripts from all schemas in the database
|
||||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||||
if db == nil {
|
if db == nil {
|
||||||
@@ -86,20 +91,39 @@ func (w *Writer) WriteTable(table *models.Table) error {
|
|||||||
return fmt.Errorf("WriteTable is not supported for SQL script execution")
|
return fmt.Errorf("WriteTable is not supported for SQL script execution")
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeScripts executes scripts in Priority then Sequence order
|
// executeScripts executes scripts in Priority, Sequence, then Name order
|
||||||
func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*models.Script) error {
|
func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*models.Script) error {
|
||||||
if len(scripts) == 0 {
|
if len(scripts) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort scripts by Priority (ascending) then Sequence (ascending)
|
// Check if we should ignore errors
|
||||||
|
ignoreErrors := false
|
||||||
|
if val, ok := w.options.Metadata["ignore_errors"].(bool); ok {
|
||||||
|
ignoreErrors = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track failed scripts and execution counts
|
||||||
|
var failedScripts []struct {
|
||||||
|
name string
|
||||||
|
priority int
|
||||||
|
sequence uint
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
successCount := 0
|
||||||
|
totalCount := 0
|
||||||
|
|
||||||
|
// Sort scripts by Priority (ascending), Sequence (ascending), then Name (ascending)
|
||||||
sortedScripts := make([]*models.Script, len(scripts))
|
sortedScripts := make([]*models.Script, len(scripts))
|
||||||
copy(sortedScripts, scripts)
|
copy(sortedScripts, scripts)
|
||||||
sort.Slice(sortedScripts, func(i, j int) bool {
|
sort.Slice(sortedScripts, func(i, j int) bool {
|
||||||
if sortedScripts[i].Priority != sortedScripts[j].Priority {
|
if sortedScripts[i].Priority != sortedScripts[j].Priority {
|
||||||
return sortedScripts[i].Priority < sortedScripts[j].Priority
|
return sortedScripts[i].Priority < sortedScripts[j].Priority
|
||||||
}
|
}
|
||||||
|
if sortedScripts[i].Sequence != sortedScripts[j].Sequence {
|
||||||
return sortedScripts[i].Sequence < sortedScripts[j].Sequence
|
return sortedScripts[i].Sequence < sortedScripts[j].Sequence
|
||||||
|
}
|
||||||
|
return sortedScripts[i].Name < sortedScripts[j].Name
|
||||||
})
|
})
|
||||||
|
|
||||||
// Execute each script in order
|
// Execute each script in order
|
||||||
@@ -108,18 +132,49 @@ func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalCount++
|
||||||
fmt.Printf("Executing script: %s (Priority=%d, Sequence=%d)\n",
|
fmt.Printf("Executing script: %s (Priority=%d, Sequence=%d)\n",
|
||||||
script.Name, script.Priority, script.Sequence)
|
script.Name, script.Priority, script.Sequence)
|
||||||
|
|
||||||
// Execute the SQL script
|
// Execute the SQL script
|
||||||
_, err := conn.Exec(ctx, script.SQL)
|
_, err := conn.Exec(ctx, script.SQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute script %s (Priority=%d, Sequence=%d): %w",
|
if ignoreErrors {
|
||||||
|
fmt.Printf("⚠ Error executing %s: %v (continuing due to --ignore-errors)\n", script.Name, err)
|
||||||
|
failedScripts = append(failedScripts, struct {
|
||||||
|
name string
|
||||||
|
priority int
|
||||||
|
sequence uint
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
name: script.Name,
|
||||||
|
priority: script.Priority,
|
||||||
|
sequence: script.Sequence,
|
||||||
|
err: err,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("script %s (Priority=%d, Sequence=%d): %w",
|
||||||
script.Name, script.Priority, script.Sequence, err)
|
script.Name, script.Priority, script.Sequence, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
successCount++
|
||||||
fmt.Printf("✓ Successfully executed: %s\n", script.Name)
|
fmt.Printf("✓ Successfully executed: %s\n", script.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store execution results in metadata for caller
|
||||||
|
w.options.Metadata["execution_total"] = totalCount
|
||||||
|
w.options.Metadata["execution_success"] = successCount
|
||||||
|
w.options.Metadata["execution_failed"] = len(failedScripts)
|
||||||
|
|
||||||
|
// Print summary of failed scripts if any
|
||||||
|
if len(failedScripts) > 0 {
|
||||||
|
fmt.Printf("\n⚠ Failed Scripts Summary (%d failed):\n", len(failedScripts))
|
||||||
|
for i, failed := range failedScripts {
|
||||||
|
fmt.Printf(" %d. %s (Priority=%d, Sequence=%d)\n Error: %v\n",
|
||||||
|
i+1, failed.name, failed.priority, failed.sequence, failed.err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,13 +99,13 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestScriptSorting verifies that scripts are sorted correctly by Priority then Sequence
|
// TestScriptSorting verifies that scripts are sorted correctly by Priority, Sequence, then Name
|
||||||
func TestScriptSorting(t *testing.T) {
|
func TestScriptSorting(t *testing.T) {
|
||||||
scripts := []*models.Script{
|
scripts := []*models.Script{
|
||||||
{Name: "script1", Priority: 2, Sequence: 1, SQL: "SELECT 1;"},
|
{Name: "z_script1", Priority: 2, Sequence: 1, SQL: "SELECT 1;"},
|
||||||
{Name: "script2", Priority: 1, Sequence: 3, SQL: "SELECT 2;"},
|
{Name: "script2", Priority: 1, Sequence: 3, SQL: "SELECT 2;"},
|
||||||
{Name: "script3", Priority: 1, Sequence: 1, SQL: "SELECT 3;"},
|
{Name: "a_script3", Priority: 1, Sequence: 1, SQL: "SELECT 3;"},
|
||||||
{Name: "script4", Priority: 1, Sequence: 2, SQL: "SELECT 4;"},
|
{Name: "b_script4", Priority: 1, Sequence: 1, SQL: "SELECT 4;"},
|
||||||
{Name: "script5", Priority: 3, Sequence: 1, SQL: "SELECT 5;"},
|
{Name: "script5", Priority: 3, Sequence: 1, SQL: "SELECT 5;"},
|
||||||
{Name: "script6", Priority: 2, Sequence: 2, SQL: "SELECT 6;"},
|
{Name: "script6", Priority: 2, Sequence: 2, SQL: "SELECT 6;"},
|
||||||
}
|
}
|
||||||
@@ -114,23 +114,33 @@ func TestScriptSorting(t *testing.T) {
|
|||||||
sortedScripts := make([]*models.Script, len(scripts))
|
sortedScripts := make([]*models.Script, len(scripts))
|
||||||
copy(sortedScripts, scripts)
|
copy(sortedScripts, scripts)
|
||||||
|
|
||||||
// Use the same sorting logic from executeScripts
|
// Sort by Priority, Sequence, then Name (matching executeScripts logic)
|
||||||
for i := 0; i < len(sortedScripts)-1; i++ {
|
for i := 0; i < len(sortedScripts)-1; i++ {
|
||||||
for j := i + 1; j < len(sortedScripts); j++ {
|
for j := i + 1; j < len(sortedScripts); j++ {
|
||||||
if sortedScripts[i].Priority > sortedScripts[j].Priority ||
|
si, sj := sortedScripts[i], sortedScripts[j]
|
||||||
(sortedScripts[i].Priority == sortedScripts[j].Priority &&
|
// Compare by priority first
|
||||||
sortedScripts[i].Sequence > sortedScripts[j].Sequence) {
|
if si.Priority > sj.Priority {
|
||||||
|
sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
|
||||||
|
} else if si.Priority == sj.Priority {
|
||||||
|
// If same priority, compare by sequence
|
||||||
|
if si.Sequence > sj.Sequence {
|
||||||
|
sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
|
||||||
|
} else if si.Sequence == sj.Sequence {
|
||||||
|
// If same sequence, compare by name
|
||||||
|
if si.Name > sj.Name {
|
||||||
sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
|
sortedScripts[i], sortedScripts[j] = sortedScripts[j], sortedScripts[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Expected order after sorting
|
// Expected order after sorting (Priority -> Sequence -> Name)
|
||||||
expectedOrder := []string{
|
expectedOrder := []string{
|
||||||
"script3", // Priority 1, Sequence 1
|
"a_script3", // Priority 1, Sequence 1, Name a_script3
|
||||||
"script4", // Priority 1, Sequence 2
|
"b_script4", // Priority 1, Sequence 1, Name b_script4
|
||||||
"script2", // Priority 1, Sequence 3
|
"script2", // Priority 1, Sequence 3
|
||||||
"script1", // Priority 2, Sequence 1
|
"z_script1", // Priority 2, Sequence 1
|
||||||
"script6", // Priority 2, Sequence 2
|
"script6", // Priority 2, Sequence 2
|
||||||
"script5", // Priority 3, Sequence 1
|
"script5", // Priority 3, Sequence 1
|
||||||
}
|
}
|
||||||
@@ -153,6 +163,13 @@ func TestScriptSorting(t *testing.T) {
|
|||||||
t.Errorf("Sequence not ascending at position %d with same priority %d: %d > %d",
|
t.Errorf("Sequence not ascending at position %d with same priority %d: %d > %d",
|
||||||
i, sortedScripts[i].Priority, sortedScripts[i].Sequence, sortedScripts[i+1].Sequence)
|
i, sortedScripts[i].Priority, sortedScripts[i].Sequence, sortedScripts[i+1].Sequence)
|
||||||
}
|
}
|
||||||
|
// Within same priority and sequence, names should be ascending
|
||||||
|
if sortedScripts[i].Priority == sortedScripts[i+1].Priority &&
|
||||||
|
sortedScripts[i].Sequence == sortedScripts[i+1].Sequence &&
|
||||||
|
sortedScripts[i].Name > sortedScripts[i+1].Name {
|
||||||
|
t.Errorf("Name not ascending at position %d with same priority/sequence: %s > %s",
|
||||||
|
i, sortedScripts[i].Name, sortedScripts[i+1].Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user