Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2ec9991324 | |||
| a3e45c206d | |||
| 165623bb1d | |||
| 3c20c3c5d9 | |||
| a54594e49b |
@@ -120,8 +120,10 @@ func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema,
|
||||
}
|
||||
|
||||
if tgtTable, exists := existingTables[tableName]; exists {
|
||||
// Table exists, merge its columns
|
||||
// Table exists, merge its columns, constraints, and indexes
|
||||
r.mergeColumns(tgtTable, srcTable)
|
||||
r.mergeConstraints(tgtTable, srcTable)
|
||||
r.mergeIndexes(tgtTable, srcTable)
|
||||
} else {
|
||||
// Table doesn't exist, add it
|
||||
newTable := cloneTable(srcTable)
|
||||
@@ -151,6 +153,50 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Table) {
|
||||
// Initialize constraints map if nil
|
||||
if table.Constraints == nil {
|
||||
table.Constraints = make(map[string]*models.Constraint)
|
||||
}
|
||||
|
||||
// Create map of existing constraints
|
||||
existingConstraints := make(map[string]*models.Constraint)
|
||||
for constName := range table.Constraints {
|
||||
existingConstraints[constName] = table.Constraints[constName]
|
||||
}
|
||||
|
||||
// Merge constraints
|
||||
for constName, srcConst := range srcTable.Constraints {
|
||||
if _, exists := existingConstraints[constName]; !exists {
|
||||
// Constraint doesn't exist, add it
|
||||
newConst := cloneConstraint(srcConst)
|
||||
table.Constraints[constName] = newConst
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table) {
|
||||
// Initialize indexes map if nil
|
||||
if table.Indexes == nil {
|
||||
table.Indexes = make(map[string]*models.Index)
|
||||
}
|
||||
|
||||
// Create map of existing indexes
|
||||
existingIndexes := make(map[string]*models.Index)
|
||||
for idxName := range table.Indexes {
|
||||
existingIndexes[idxName] = table.Indexes[idxName]
|
||||
}
|
||||
|
||||
// Merge indexes
|
||||
for idxName, srcIdx := range srcTable.Indexes {
|
||||
if _, exists := existingIndexes[idxName]; !exists {
|
||||
// Index doesn't exist, add it
|
||||
newIdx := cloneIndex(srcIdx)
|
||||
table.Indexes[idxName] = newIdx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
|
||||
// Create map of existing views
|
||||
existingViews := make(map[string]*models.View)
|
||||
|
||||
@@ -603,8 +603,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
||||
column.Default = strings.Trim(defaultVal, "'\"")
|
||||
} else if attr == "unique" {
|
||||
// Create a unique constraint
|
||||
// Clean table name by removing leading underscores to avoid double underscores
|
||||
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||
uniqueConstraint := models.InitConstraint(
|
||||
fmt.Sprintf("uq_%s", columnName),
|
||||
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||
models.UniqueConstraint,
|
||||
)
|
||||
uniqueConstraint.Schema = schemaName
|
||||
@@ -652,8 +654,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = []string{columnName}
|
||||
}
|
||||
// Generate short constraint name based on the column
|
||||
constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
|
||||
// Generate constraint name based on table and columns
|
||||
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -737,7 +739,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
||||
|
||||
// Generate name if not provided
|
||||
if index.Name == "" {
|
||||
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
|
||||
prefix := "idx"
|
||||
if index.Unique {
|
||||
prefix = "uidx"
|
||||
}
|
||||
index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_"))
|
||||
}
|
||||
|
||||
return index
|
||||
@@ -797,10 +803,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate short constraint name based on the source column
|
||||
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
||||
if len(fromColumns) > 0 {
|
||||
constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
|
||||
// Generate constraint name based on table and columns
|
||||
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_"))
|
||||
if len(fromColumns) == 0 {
|
||||
constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(
|
||||
|
||||
@@ -777,6 +777,76 @@ func TestParseFilePrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstraintNaming(t *testing.T) {
|
||||
// Test that constraints are named with proper prefixes
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
// Find users table
|
||||
var usersTable *models.Table
|
||||
var postsTable *models.Table
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name == "users" {
|
||||
usersTable = table
|
||||
} else if table.Name == "posts" {
|
||||
postsTable = table
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usersTable == nil {
|
||||
t.Fatal("Users table not found")
|
||||
}
|
||||
if postsTable == nil {
|
||||
t.Fatal("Posts table not found")
|
||||
}
|
||||
|
||||
// Test unique constraint naming: ukey_table_column
|
||||
if _, exists := usersTable.Constraints["ukey_users_email"]; !exists {
|
||||
t.Error("Expected unique constraint 'ukey_users_email' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(usersTable.Constraints))
|
||||
}
|
||||
|
||||
if _, exists := postsTable.Constraints["ukey_posts_slug"]; !exists {
|
||||
t.Error("Expected unique constraint 'ukey_posts_slug' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||
}
|
||||
|
||||
// Test foreign key naming: fk_table_column
|
||||
if _, exists := postsTable.Constraints["fk_posts_user_id"]; !exists {
|
||||
t.Error("Expected foreign key 'fk_posts_user_id' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||
}
|
||||
|
||||
// Test unique index naming: uidx_table_columns
|
||||
if _, exists := postsTable.Indexes["uidx_posts_slug"]; !exists {
|
||||
t.Error("Expected unique index 'uidx_posts_slug' not found")
|
||||
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||
}
|
||||
|
||||
// Test regular index naming: idx_table_columns
|
||||
if _, exists := postsTable.Indexes["idx_posts_user_id_published"]; !exists {
|
||||
t.Error("Expected index 'idx_posts_user_id_published' not found")
|
||||
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||
}
|
||||
}
|
||||
|
||||
func getKeys[V any](m map[string]V) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func TestHasCommentedRefs(t *testing.T) {
|
||||
// Test with the actual multifile test fixtures
|
||||
tests := []struct {
|
||||
|
||||
@@ -215,3 +215,70 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
||||
t.Error("SQL missing DELETE handling")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_NumericConstraintNames(t *testing.T) {
|
||||
// Current database (empty)
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("entity")
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
// Model database (with constraint starting with number)
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("entity")
|
||||
|
||||
// Create individual_actor_relationship table
|
||||
table := models.InitTable("individual_actor_relationship", "entity")
|
||||
idCol := models.InitColumn("id", "individual_actor_relationship", "entity")
|
||||
idCol.Type = "integer"
|
||||
idCol.IsPrimaryKey = true
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
actorIDCol := models.InitColumn("actor_id", "individual_actor_relationship", "entity")
|
||||
actorIDCol.Type = "integer"
|
||||
table.Columns["actor_id"] = actorIDCol
|
||||
|
||||
// Add constraint with name starting with number
|
||||
constraint := &models.Constraint{
|
||||
Name: "215162_fk_actor",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"actor_id"},
|
||||
ReferencedSchema: "entity",
|
||||
ReferencedTable: "actor",
|
||||
ReferencedColumns: []string{"id"},
|
||||
OnDelete: "CASCADE",
|
||||
OnUpdate: "NO ACTION",
|
||||
}
|
||||
table.Constraints["215162_fk_actor"] = constraint
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
// Generate migration
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
err = writer.WriteMigration(model, current)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
t.Logf("Generated migration:\n%s", output)
|
||||
|
||||
// Verify constraint name is properly quoted
|
||||
if !strings.Contains(output, `"215162_fk_actor"`) {
|
||||
t.Error("Constraint name starting with number should be quoted")
|
||||
}
|
||||
|
||||
// Verify the SQL is syntactically correct (contains required keywords)
|
||||
if !strings.Contains(output, "ADD CONSTRAINT") {
|
||||
t.Error("Migration missing ADD CONSTRAINT")
|
||||
}
|
||||
if !strings.Contains(output, "FOREIGN KEY") {
|
||||
t.Error("Migration missing FOREIGN KEY")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ func TemplateFunctions() map[string]interface{} {
|
||||
"quote": quote,
|
||||
"escape": escape,
|
||||
"safe_identifier": safeIdentifier,
|
||||
"quote_ident": quoteIdent,
|
||||
|
||||
// Type conversion
|
||||
"goTypeToSQL": goTypeToSQL,
|
||||
@@ -122,6 +123,43 @@ func safeIdentifier(s string) string {
|
||||
return strings.ToLower(safe)
|
||||
}
|
||||
|
||||
// quoteIdent quotes a PostgreSQL identifier if necessary
|
||||
// Identifiers need quoting if they:
|
||||
// - Start with a digit
|
||||
// - Contain special characters
|
||||
// - Are reserved keywords
|
||||
// - Contain uppercase letters (to preserve case)
|
||||
func quoteIdent(s string) string {
|
||||
if s == "" {
|
||||
return `""`
|
||||
}
|
||||
|
||||
// Check if quoting is needed
|
||||
needsQuoting := unicode.IsDigit(rune(s[0]))
|
||||
|
||||
// Starts with digit
|
||||
|
||||
// Contains uppercase letters or special characters
|
||||
for _, r := range s {
|
||||
if unicode.IsUpper(r) {
|
||||
needsQuoting = true
|
||||
break
|
||||
}
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
|
||||
needsQuoting = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsQuoting {
|
||||
// Escape double quotes by doubling them
|
||||
escaped := strings.ReplaceAll(s, `"`, `""`)
|
||||
return `"` + escaped + `"`
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Type conversion functions
|
||||
|
||||
// goTypeToSQL converts Go type to PostgreSQL type
|
||||
|
||||
@@ -101,6 +101,31 @@ func TestSafeIdentifier(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteIdent(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"valid_name", "valid_name"},
|
||||
{"ValidName", `"ValidName"`},
|
||||
{"123column", `"123column"`},
|
||||
{"215162_fk_constraint", `"215162_fk_constraint"`},
|
||||
{"user-id", `"user-id"`},
|
||||
{"user@domain", `"user@domain"`},
|
||||
{`"quoted"`, `"""quoted"""`},
|
||||
{"", `""`},
|
||||
{"lowercase", "lowercase"},
|
||||
{"with_underscore", "with_underscore"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := quoteIdent(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("quoteIdent(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoTypeToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
@@ -243,7 +268,7 @@ func TestTemplateFunctions(t *testing.T) {
|
||||
// Check that all expected functions are registered
|
||||
expectedFuncs := []string{
|
||||
"upper", "lower", "snake_case", "camelCase",
|
||||
"indent", "quote", "escape", "safe_identifier",
|
||||
"indent", "quote", "escape", "safe_identifier", "quote_ident",
|
||||
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
|
||||
"first", "last", "filter", "mapFunc", "join_with",
|
||||
"join",
|
||||
|
||||
@@ -177,6 +177,72 @@ type AuditTriggerData struct {
|
||||
Events string
|
||||
}
|
||||
|
||||
// CreateUniqueConstraintData contains data for create unique constraint template
|
||||
type CreateUniqueConstraintData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
Columns string
|
||||
}
|
||||
|
||||
// CreateCheckConstraintData contains data for create check constraint template
|
||||
type CreateCheckConstraintData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
Expression string
|
||||
}
|
||||
|
||||
// CreateForeignKeyWithCheckData contains data for create foreign key with existence check template
|
||||
type CreateForeignKeyWithCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
SourceColumns string
|
||||
TargetSchema string
|
||||
TargetTable string
|
||||
TargetColumns string
|
||||
OnDelete string
|
||||
OnUpdate string
|
||||
Deferrable bool
|
||||
}
|
||||
|
||||
// SetSequenceValueData contains data for set sequence value template
|
||||
type SetSequenceValueData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
SequenceName string
|
||||
ColumnName string
|
||||
}
|
||||
|
||||
// CreateSequenceData contains data for create sequence template
|
||||
type CreateSequenceData struct {
|
||||
SchemaName string
|
||||
SequenceName string
|
||||
Increment int
|
||||
MinValue int64
|
||||
MaxValue int64
|
||||
StartValue int64
|
||||
CacheSize int
|
||||
}
|
||||
|
||||
// AddColumnWithCheckData contains data for add column with existence check template
|
||||
type AddColumnWithCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
ColumnDefinition string
|
||||
}
|
||||
|
||||
// CreatePrimaryKeyWithAutoGenCheckData contains data for primary key with auto-generated key check template
|
||||
type CreatePrimaryKeyWithAutoGenCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
|
||||
Columns string
|
||||
}
|
||||
|
||||
// Execute methods for each template
|
||||
|
||||
// ExecuteCreateTable executes the create table template
|
||||
@@ -319,6 +385,76 @@ func (te *TemplateExecutor) ExecuteAuditTrigger(data AuditTriggerData) (string,
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateUniqueConstraint executes the create unique constraint template
|
||||
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data CreateUniqueConstraintData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_unique_constraint.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_unique_constraint template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateCheckConstraint executes the create check constraint template
|
||||
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data CreateCheckConstraintData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_check_constraint.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_check_constraint template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateForeignKeyWithCheck executes the create foreign key with check template
|
||||
func (te *TemplateExecutor) ExecuteCreateForeignKeyWithCheck(data CreateForeignKeyWithCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_foreign_key_with_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_foreign_key_with_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteSetSequenceValue executes the set sequence value template
|
||||
func (te *TemplateExecutor) ExecuteSetSequenceValue(data SetSequenceValueData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "set_sequence_value.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute set_sequence_value template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateSequence executes the create sequence template
|
||||
func (te *TemplateExecutor) ExecuteCreateSequence(data CreateSequenceData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_sequence.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_sequence template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteAddColumnWithCheck executes the add column with check template
|
||||
func (te *TemplateExecutor) ExecuteAddColumnWithCheck(data AddColumnWithCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "add_column_with_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute add_column_with_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreatePrimaryKeyWithAutoGenCheck executes the create primary key with auto-generated key check template
|
||||
func (te *TemplateExecutor) ExecuteCreatePrimaryKeyWithAutoGenCheck(data CreatePrimaryKeyWithAutoGenCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_primary_key_with_autogen_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_primary_key_with_autogen_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// Helper functions to build template data from models
|
||||
|
||||
// BuildCreateTableData builds CreateTableData from a models.Table
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD COLUMN IF NOT EXISTS {{.ColumnName}} {{.ColumnType}}
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
|
||||
{{- if .Default}} DEFAULT {{.Default}}{{end}}
|
||||
{{- if .NotNull}} NOT NULL{{end}};
|
||||
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND column_name = '{{.ColumnName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD COLUMN {{.ColumnDefinition}};
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,7 +1,7 @@
|
||||
{{- if .SetDefault -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||
{{- else -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} DROP DEFAULT;
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
|
||||
{{- end -}}
|
||||
@@ -1,2 +1,2 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} TYPE {{.NewType}};
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
||||
@@ -1 +1 @@
|
||||
COMMENT ON COLUMN {{.SchemaName}}.{{.TableName}}.{{.ColumnName}} IS '{{.Comment}}';
|
||||
COMMENT ON COLUMN {{quote_ident .SchemaName}}.{{quote_ident .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';
|
||||
@@ -1 +1 @@
|
||||
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS '{{.Comment}}';
|
||||
COMMENT ON TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} IS '{{.Comment}}';
|
||||
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,10 +1,10 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD CONSTRAINT {{.ConstraintName}}
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||
FOREIGN KEY ({{.SourceColumns}})
|
||||
REFERENCES {{.TargetSchema}}.{{.TargetTable}} ({{.TargetColumns}})
|
||||
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
|
||||
ON DELETE {{.OnDelete}}
|
||||
ON UPDATE {{.OnUpdate}}
|
||||
DEFERRABLE;
|
||||
@@ -0,0 +1,18 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||
FOREIGN KEY ({{.SourceColumns}})
|
||||
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
|
||||
ON DELETE {{.OnDelete}}
|
||||
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
|
||||
DEFERRABLE{{end}};
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,2 +1,2 @@
|
||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
|
||||
ON {{.SchemaName}}.{{.TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
|
||||
ON {{quote_ident .SchemaName}}.{{quote_ident .TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||
@@ -6,8 +6,8 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD CONSTRAINT {{.ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -0,0 +1,27 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
auto_pk_name text;
|
||||
BEGIN
|
||||
-- Drop auto-generated primary key if it exists
|
||||
SELECT constraint_name INTO auto_pk_name
|
||||
FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_type = 'PRIMARY KEY'
|
||||
AND constraint_name IN ({{.AutoGenNames}});
|
||||
|
||||
IF auto_pk_name IS NOT NULL THEN
|
||||
EXECUTE 'ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||
END IF;
|
||||
|
||||
-- Add named primary key if it doesn't exist
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
@@ -0,0 +1,6 @@
|
||||
CREATE SEQUENCE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}
|
||||
INCREMENT {{.Increment}}
|
||||
MINVALUE {{.MinValue}}
|
||||
MAXVALUE {{.MaxValue}}
|
||||
START {{.StartValue}}
|
||||
CACHE {{.CacheSize}};
|
||||
@@ -1,7 +1,7 @@
|
||||
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
|
||||
CREATE TABLE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .TableName}} (
|
||||
{{- range $i, $col := .Columns}}
|
||||
{{- if $i}},{{end}}
|
||||
{{$col.Name}} {{$col.Type}}
|
||||
{{quote_ident $col.Name}} {{$col.Type}}
|
||||
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
||||
{{- if $col.NotNull}} NOT NULL{{end}}
|
||||
{{- end}}
|
||||
|
||||
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1 +1 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}} DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
@@ -1 +1 @@
|
||||
DROP INDEX IF EXISTS {{.SchemaName}}.{{.IndexName}} CASCADE;
|
||||
DROP INDEX IF EXISTS {{quote_ident .SchemaName}}.{{quote_ident .IndexName}} CASCADE;
|
||||
19
pkg/writers/pgsql/templates/set_sequence_value.tmpl
Normal file
19
pkg/writers/pgsql/templates/set_sequence_value.tmpl
Normal file
@@ -0,0 +1,19 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
m_cnt bigint;
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM pg_class c
|
||||
INNER JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relname = '{{.SequenceName}}'
|
||||
AND n.nspname = '{{.SchemaName}}'
|
||||
AND c.relkind = 'S'
|
||||
) THEN
|
||||
SELECT COALESCE(MAX({{quote_ident .ColumnName}}), 0) + 1
|
||||
FROM {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
INTO m_cnt;
|
||||
|
||||
PERFORM setval('{{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}'::regclass, m_cnt);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -22,6 +22,7 @@ type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
writer io.Writer
|
||||
executionReport *ExecutionReport
|
||||
executor *TemplateExecutor
|
||||
}
|
||||
|
||||
// ExecutionReport tracks the execution status of SQL statements
|
||||
@@ -57,8 +58,10 @@ type ExecutionError struct {
|
||||
|
||||
// NewWriter creates a new PostgreSQL SQL writer
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
executor, _ := NewTemplateExecutor()
|
||||
return &Writer{
|
||||
options: options,
|
||||
executor: executor,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,36 +218,19 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
|
||||
}
|
||||
|
||||
// Wrap in DO block to drop auto-generated PK and add our named PK
|
||||
stmt := fmt.Sprintf("DO $$\nDECLARE\n"+
|
||||
" auto_pk_name text;\n"+
|
||||
"BEGIN\n"+
|
||||
" -- Drop auto-generated primary key if it exists\n"+
|
||||
" SELECT constraint_name INTO auto_pk_name\n"+
|
||||
" FROM information_schema.table_constraints\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND constraint_type = 'PRIMARY KEY'\n"+
|
||||
" AND constraint_name IN (%s);\n"+
|
||||
"\n"+
|
||||
" IF auto_pk_name IS NOT NULL THEN\n"+
|
||||
" EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n"+
|
||||
" END IF;\n"+
|
||||
"\n"+
|
||||
" -- Add named primary key if it doesn't exist\n"+
|
||||
" IF NOT EXISTS (\n"+
|
||||
" SELECT 1 FROM information_schema.table_constraints\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND constraint_name = '%s'\n"+
|
||||
" ) THEN\n"+
|
||||
" ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s);\n"+
|
||||
" END IF;\n"+
|
||||
"END;\n$$",
|
||||
schema.Name, table.Name, formatStringList(autoGenPKNames),
|
||||
schema.SQLName(), table.SQLName(),
|
||||
schema.Name, table.Name, pkName,
|
||||
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
|
||||
// Use template to generate primary key statement
|
||||
data := CreatePrimaryKeyWithAutoGenCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: pkName,
|
||||
AutoGenNames: formatStringList(autoGenPKNames),
|
||||
Columns: strings.Join(pkColumns, ", "),
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -290,7 +276,53 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
||||
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
uniqueStr, quoteIdentifier(index.Name), schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5.5: Unique constraints
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.UniqueConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use template to generate unique constraint statement
|
||||
data := CreateUniqueConstraintData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
Columns: strings.Join(constraint.Columns, ", "),
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreateUniqueConstraint(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate unique constraint for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5.7: Check constraints
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.CheckConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use template to generate check constraint statement
|
||||
data := CreateCheckConstraintData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
Expression: constraint.Expression,
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreateCheckConstraint(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate check constraint for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -317,23 +349,24 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
onUpdate = "NO ACTION"
|
||||
}
|
||||
|
||||
// Wrap in DO block to check for existing constraint
|
||||
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
|
||||
" IF NOT EXISTS (\n"+
|
||||
" SELECT 1 FROM information_schema.table_constraints\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND constraint_name = '%s'\n"+
|
||||
" ) THEN\n"+
|
||||
" ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s;\n"+
|
||||
" END IF;\n"+
|
||||
"END;\n$$",
|
||||
schema.Name, table.Name, constraint.Name,
|
||||
schema.SQLName(), table.SQLName(), constraint.Name,
|
||||
strings.Join(constraint.Columns, ", "),
|
||||
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
||||
strings.Join(constraint.ReferencedColumns, ", "),
|
||||
onDelete, onUpdate)
|
||||
// Use template to generate foreign key statement
|
||||
data := CreateForeignKeyWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
SourceColumns: strings.Join(constraint.Columns, ", "),
|
||||
TargetSchema: refSchema,
|
||||
TargetTable: constraint.ReferencedTable,
|
||||
TargetColumns: strings.Join(constraint.ReferencedColumns, ", "),
|
||||
OnDelete: onDelete,
|
||||
OnUpdate: onUpdate,
|
||||
Deferrable: false,
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -381,19 +414,18 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e
|
||||
for _, col := range columns {
|
||||
colDef := w.generateColumnDefinition(col)
|
||||
|
||||
// Generate DO block that checks if column exists before adding
|
||||
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
|
||||
" IF NOT EXISTS (\n"+
|
||||
" SELECT 1 FROM information_schema.columns\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND column_name = '%s'\n"+
|
||||
" ) THEN\n"+
|
||||
" ALTER TABLE %s.%s ADD COLUMN %s;\n"+
|
||||
" END IF;\n"+
|
||||
"END;\n$$",
|
||||
schema.Name, table.Name, col.Name,
|
||||
schema.SQLName(), table.SQLName(), colDef)
|
||||
// Use template to generate add column statement
|
||||
data := AddColumnWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ColumnName: col.Name,
|
||||
ColumnDefinition: colDef,
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteAddColumnWithCheck(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -542,6 +574,16 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 5.5: Create unique constraints (priority 185)
|
||||
if err := w.writeUniqueConstraints(schema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 5.7: Create check constraints (priority 190)
|
||||
if err := w.writeCheckConstraints(schema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 6: Create foreign key constraints (priority 195)
|
||||
if err := w.writeForeignKeys(schema); err != nil {
|
||||
return err
|
||||
@@ -639,13 +681,23 @@ func (w *Writer) writeSequences(schema *models.Schema) error {
|
||||
}
|
||||
|
||||
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
||||
fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n",
|
||||
schema.SQLName(), seqName)
|
||||
fmt.Fprintf(w.writer, " INCREMENT 1\n")
|
||||
fmt.Fprintf(w.writer, " MINVALUE 1\n")
|
||||
fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n")
|
||||
fmt.Fprintf(w.writer, " START 1\n")
|
||||
fmt.Fprintf(w.writer, " CACHE 1;\n\n")
|
||||
|
||||
data := CreateSequenceData{
|
||||
SchemaName: schema.Name,
|
||||
SequenceName: seqName,
|
||||
Increment: 1,
|
||||
MinValue: 1,
|
||||
MaxValue: 9223372036854775807,
|
||||
StartValue: 1,
|
||||
CacheSize: 1,
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteCreateSequence(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate create sequence for %s.%s: %w", schema.Name, seqName, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -687,18 +739,19 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error {
|
||||
for _, col := range columns {
|
||||
colDef := w.generateColumnDefinition(col)
|
||||
|
||||
// Generate DO block that checks if column exists before adding
|
||||
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.columns\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND column_name = '%s'\n", col.Name)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s ADD COLUMN %s;\n",
|
||||
schema.SQLName(), table.SQLName(), colDef)
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||
data := AddColumnWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ColumnName: col.Name,
|
||||
ColumnDefinition: colDef,
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteAddColumnWithCheck(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -752,37 +805,20 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
||||
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "DO $$\nDECLARE\n")
|
||||
fmt.Fprintf(w.writer, " auto_pk_name text;\nBEGIN\n")
|
||||
data := CreatePrimaryKeyWithAutoGenCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: pkName,
|
||||
AutoGenNames: formatStringList(autoGenPKNames),
|
||||
Columns: strings.Join(columnNames, ", "),
|
||||
}
|
||||
|
||||
// Check for and drop auto-generated primary keys
|
||||
fmt.Fprintf(w.writer, " -- Drop auto-generated primary key if it exists\n")
|
||||
fmt.Fprintf(w.writer, " SELECT constraint_name INTO auto_pk_name\n")
|
||||
fmt.Fprintf(w.writer, " FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_type = 'PRIMARY KEY'\n")
|
||||
fmt.Fprintf(w.writer, " AND constraint_name IN (%s);\n", formatStringList(autoGenPKNames))
|
||||
fmt.Fprintf(w.writer, "\n")
|
||||
fmt.Fprintf(w.writer, " IF auto_pk_name IS NOT NULL THEN\n")
|
||||
fmt.Fprintf(w.writer, " EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n",
|
||||
schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "\n")
|
||||
|
||||
// Add our named primary key if it doesn't exist
|
||||
fmt.Fprintf(w.writer, " -- Add named primary key if it doesn't exist\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n",
|
||||
pkName, strings.Join(columnNames, ", "))
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -865,6 +901,91 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeUniqueConstraints generates ALTER TABLE statements for unique constraints
|
||||
func (w *Writer) writeUniqueConstraints(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Unique constraints for schema: %s\n", schema.Name)
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
// Sort constraints by name for consistent output
|
||||
constraintNames := make([]string, 0, len(table.Constraints))
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.UniqueConstraint {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Build column list
|
||||
columnExprs := make([]string, 0, len(constraint.Columns))
|
||||
for _, colName := range constraint.Columns {
|
||||
if col, ok := table.Columns[colName]; ok {
|
||||
columnExprs = append(columnExprs, col.SQLName())
|
||||
}
|
||||
}
|
||||
|
||||
if len(columnExprs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteCreateUniqueConstraint(CreateUniqueConstraintData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
Columns: strings.Join(columnExprs, ", "),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate unique constraint: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "%s\n\n", sql)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCheckConstraints generates ALTER TABLE statements for check constraints
|
||||
func (w *Writer) writeCheckConstraints(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Check constraints for schema: %s\n", schema.Name)
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
// Sort constraints by name for consistent output
|
||||
constraintNames := make([]string, 0, len(table.Constraints))
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.CheckConstraint {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Skip if expression is empty
|
||||
if constraint.Expression == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteCreateCheckConstraint(CreateCheckConstraintData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
Expression: constraint.Expression,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate check constraint: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "%s\n\n", sql)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeForeignKeys generates ALTER TABLE statements for foreign keys
|
||||
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
|
||||
@@ -942,24 +1063,103 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
||||
refTable = rel.ToTable
|
||||
}
|
||||
|
||||
// Use DO block to check if constraint exists before adding
|
||||
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", fkName)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
|
||||
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
|
||||
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
|
||||
refSchema, refTable, strings.Join(targetColumns, ", "))
|
||||
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
|
||||
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
|
||||
fmt.Fprintf(w.writer, " DEFERRABLE;\n")
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||
// Use template executor to generate foreign key with existence check
|
||||
data := CreateForeignKeyWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: fkName,
|
||||
SourceColumns: strings.Join(sourceColumns, ", "),
|
||||
TargetSchema: refSchema,
|
||||
TargetTable: refTable,
|
||||
TargetColumns: strings.Join(targetColumns, ", "),
|
||||
OnDelete: onDelete,
|
||||
OnUpdate: onUpdate,
|
||||
Deferrable: true,
|
||||
}
|
||||
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
}
|
||||
|
||||
// Also process any foreign key constraints that don't have a relationship
|
||||
processedConstraints := make(map[string]bool)
|
||||
for _, rel := range table.Relationships {
|
||||
fkName := rel.ForeignKey
|
||||
if fkName == "" {
|
||||
fkName = rel.Name
|
||||
}
|
||||
if fkName != "" {
|
||||
processedConstraints[fkName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Find unprocessed foreign key constraints
|
||||
constraintNames := make([]string, 0)
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint && !processedConstraints[name] {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Build column lists
|
||||
sourceColumns := make([]string, 0, len(constraint.Columns))
|
||||
for _, colName := range constraint.Columns {
|
||||
if col, ok := table.Columns[colName]; ok {
|
||||
sourceColumns = append(sourceColumns, col.SQLName())
|
||||
} else {
|
||||
sourceColumns = append(sourceColumns, colName)
|
||||
}
|
||||
}
|
||||
|
||||
targetColumns := make([]string, 0, len(constraint.ReferencedColumns))
|
||||
for _, colName := range constraint.ReferencedColumns {
|
||||
targetColumns = append(targetColumns, strings.ToLower(colName))
|
||||
}
|
||||
|
||||
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
onDelete := "NO ACTION"
|
||||
if constraint.OnDelete != "" {
|
||||
onDelete = strings.ToUpper(constraint.OnDelete)
|
||||
}
|
||||
|
||||
onUpdate := "NO ACTION"
|
||||
if constraint.OnUpdate != "" {
|
||||
onUpdate = strings.ToUpper(constraint.OnUpdate)
|
||||
}
|
||||
|
||||
refSchema := constraint.ReferencedSchema
|
||||
if refSchema == "" {
|
||||
refSchema = schema.Name
|
||||
}
|
||||
refTable := constraint.ReferencedTable
|
||||
|
||||
// Use template executor to generate foreign key with existence check
|
||||
data := CreateForeignKeyWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
SourceColumns: strings.Join(sourceColumns, ", "),
|
||||
TargetSchema: refSchema,
|
||||
TargetTable: refTable,
|
||||
TargetColumns: strings.Join(targetColumns, ", "),
|
||||
OnDelete: onDelete,
|
||||
OnUpdate: onUpdate,
|
||||
Deferrable: false,
|
||||
}
|
||||
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -978,26 +1178,19 @@ func (w *Writer) writeSetSequenceValues(schema *models.Schema) error {
|
||||
|
||||
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
||||
|
||||
fmt.Fprintf(w.writer, "DO $$\n")
|
||||
fmt.Fprintf(w.writer, "DECLARE\n")
|
||||
fmt.Fprintf(w.writer, " m_cnt bigint;\n")
|
||||
fmt.Fprintf(w.writer, "BEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n")
|
||||
fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n")
|
||||
fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName)
|
||||
fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n")
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName())
|
||||
fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " INTO m_cnt;\n")
|
||||
fmt.Fprintf(w.writer, " \n")
|
||||
fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n",
|
||||
schema.SQLName(), seqName)
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n")
|
||||
fmt.Fprintf(w.writer, "$$;\n\n")
|
||||
// Use template executor to generate set sequence value statement
|
||||
data := SetSequenceValueData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
SequenceName: seqName,
|
||||
ColumnName: pk.Name,
|
||||
}
|
||||
sql, err := w.executor.ExecuteSetSequenceValue(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate set sequence value for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1217,7 +1410,8 @@ func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) erro
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Executing statement %d/%d...\n", i+1, len(statements))
|
||||
stmtType := detectStatementType(stmtTrimmed)
|
||||
fmt.Fprintf(os.Stderr, "Executing statement %d/%d [%s]...\n", i+1, len(statements), stmtType)
|
||||
|
||||
_, execErr := conn.Exec(ctx, stmt)
|
||||
if execErr != nil {
|
||||
@@ -1351,3 +1545,94 @@ func truncateStatement(stmt string) string {
|
||||
func getCurrentTimestamp() string {
|
||||
return time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// detectStatementType detects the type of SQL statement for logging
|
||||
func detectStatementType(stmt string) string {
|
||||
upperStmt := strings.ToUpper(stmt)
|
||||
|
||||
// Check for DO blocks (used for conditional DDL)
|
||||
if strings.HasPrefix(upperStmt, "DO $$") || strings.HasPrefix(upperStmt, "DO $") {
|
||||
// Look inside the DO block for the actual operation
|
||||
if strings.Contains(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD CONSTRAINT") {
|
||||
if strings.Contains(upperStmt, "UNIQUE") {
|
||||
return "ADD UNIQUE CONSTRAINT"
|
||||
} else if strings.Contains(upperStmt, "FOREIGN KEY") {
|
||||
return "ADD FOREIGN KEY"
|
||||
} else if strings.Contains(upperStmt, "PRIMARY KEY") {
|
||||
return "ADD PRIMARY KEY"
|
||||
} else if strings.Contains(upperStmt, "CHECK") {
|
||||
return "ADD CHECK CONSTRAINT"
|
||||
}
|
||||
return "ADD CONSTRAINT"
|
||||
}
|
||||
if strings.Contains(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD COLUMN") {
|
||||
return "ADD COLUMN"
|
||||
}
|
||||
if strings.Contains(upperStmt, "DROP CONSTRAINT") {
|
||||
return "DROP CONSTRAINT"
|
||||
}
|
||||
return "DO BLOCK"
|
||||
}
|
||||
|
||||
// Direct DDL statements
|
||||
if strings.HasPrefix(upperStmt, "CREATE SCHEMA") {
|
||||
return "CREATE SCHEMA"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "CREATE SEQUENCE") {
|
||||
return "CREATE SEQUENCE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "CREATE TABLE") {
|
||||
return "CREATE TABLE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "CREATE INDEX") {
|
||||
return "CREATE INDEX"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "CREATE UNIQUE INDEX") {
|
||||
return "CREATE UNIQUE INDEX"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "ALTER TABLE") {
|
||||
if strings.Contains(upperStmt, "ADD CONSTRAINT") {
|
||||
if strings.Contains(upperStmt, "FOREIGN KEY") {
|
||||
return "ADD FOREIGN KEY"
|
||||
} else if strings.Contains(upperStmt, "PRIMARY KEY") {
|
||||
return "ADD PRIMARY KEY"
|
||||
} else if strings.Contains(upperStmt, "UNIQUE") {
|
||||
return "ADD UNIQUE CONSTRAINT"
|
||||
} else if strings.Contains(upperStmt, "CHECK") {
|
||||
return "ADD CHECK CONSTRAINT"
|
||||
}
|
||||
return "ADD CONSTRAINT"
|
||||
}
|
||||
if strings.Contains(upperStmt, "ADD COLUMN") {
|
||||
return "ADD COLUMN"
|
||||
}
|
||||
if strings.Contains(upperStmt, "DROP CONSTRAINT") {
|
||||
return "DROP CONSTRAINT"
|
||||
}
|
||||
if strings.Contains(upperStmt, "ALTER COLUMN") {
|
||||
return "ALTER COLUMN"
|
||||
}
|
||||
return "ALTER TABLE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "COMMENT ON TABLE") {
|
||||
return "COMMENT ON TABLE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "COMMENT ON COLUMN") {
|
||||
return "COMMENT ON COLUMN"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "DROP TABLE") {
|
||||
return "DROP TABLE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "DROP INDEX") {
|
||||
return "DROP INDEX"
|
||||
}
|
||||
|
||||
// Default
|
||||
return "SQL"
|
||||
}
|
||||
|
||||
// quoteIdentifier wraps an identifier in double quotes if necessary
|
||||
// This is needed for identifiers that start with numbers or contain special characters
|
||||
func quoteIdentifier(s string) string {
|
||||
return quoteIdent(s)
|
||||
}
|
||||
|
||||
@@ -164,6 +164,296 @@ func TestWriteForeignKeys(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteUniqueConstraints(t *testing.T) {
|
||||
// Create a test database with unique constraints
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create table with unique constraints
|
||||
table := models.InitTable("users", "public")
|
||||
|
||||
// Add columns
|
||||
emailCol := models.InitColumn("email", "users", "public")
|
||||
emailCol.Type = "varchar(255)"
|
||||
emailCol.NotNull = true
|
||||
table.Columns["email"] = emailCol
|
||||
|
||||
guidCol := models.InitColumn("guid", "users", "public")
|
||||
guidCol.Type = "uuid"
|
||||
guidCol.NotNull = true
|
||||
table.Columns["guid"] = guidCol
|
||||
|
||||
// Add unique constraints
|
||||
emailConstraint := &models.Constraint{
|
||||
Name: "uq_email",
|
||||
Type: models.UniqueConstraint,
|
||||
Schema: "public",
|
||||
Table: "users",
|
||||
Columns: []string{"email"},
|
||||
}
|
||||
table.Constraints["uq_email"] = emailConstraint
|
||||
|
||||
guidConstraint := &models.Constraint{
|
||||
Name: "uq_guid",
|
||||
Type: models.UniqueConstraint,
|
||||
Schema: "public",
|
||||
Table: "users",
|
||||
Columns: []string{"guid"},
|
||||
}
|
||||
table.Constraints["uq_guid"] = guidConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with output to buffer
|
||||
var buf bytes.Buffer
|
||||
options := &writers.WriterOptions{}
|
||||
writer := NewWriter(options)
|
||||
writer.writer = &buf
|
||||
|
||||
// Write the database
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Print output for debugging
|
||||
t.Logf("Generated SQL:\n%s", output)
|
||||
|
||||
// Verify unique constraints are present
|
||||
if !strings.Contains(output, "-- Unique constraints for schema: public") {
|
||||
t.Errorf("Output missing unique constraints header")
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT uq_email UNIQUE (email)") {
|
||||
t.Errorf("Output missing uq_email unique constraint\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT uq_guid UNIQUE (guid)") {
|
||||
t.Errorf("Output missing uq_guid unique constraint\nFull output:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteCheckConstraints(t *testing.T) {
|
||||
// Create a test database with check constraints
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create table with check constraints
|
||||
table := models.InitTable("products", "public")
|
||||
|
||||
// Add columns
|
||||
priceCol := models.InitColumn("price", "products", "public")
|
||||
priceCol.Type = "numeric(10,2)"
|
||||
table.Columns["price"] = priceCol
|
||||
|
||||
statusCol := models.InitColumn("status", "products", "public")
|
||||
statusCol.Type = "varchar(20)"
|
||||
table.Columns["status"] = statusCol
|
||||
|
||||
quantityCol := models.InitColumn("quantity", "products", "public")
|
||||
quantityCol.Type = "integer"
|
||||
table.Columns["quantity"] = quantityCol
|
||||
|
||||
// Add check constraints
|
||||
priceConstraint := &models.Constraint{
|
||||
Name: "ck_price_positive",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "price >= 0",
|
||||
}
|
||||
table.Constraints["ck_price_positive"] = priceConstraint
|
||||
|
||||
statusConstraint := &models.Constraint{
|
||||
Name: "ck_status_valid",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "status IN ('active', 'inactive', 'discontinued')",
|
||||
}
|
||||
table.Constraints["ck_status_valid"] = statusConstraint
|
||||
|
||||
quantityConstraint := &models.Constraint{
|
||||
Name: "ck_quantity_nonnegative",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "quantity >= 0",
|
||||
}
|
||||
table.Constraints["ck_quantity_nonnegative"] = quantityConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with output to buffer
|
||||
var buf bytes.Buffer
|
||||
options := &writers.WriterOptions{}
|
||||
writer := NewWriter(options)
|
||||
writer.writer = &buf
|
||||
|
||||
// Write the database
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Print output for debugging
|
||||
t.Logf("Generated SQL:\n%s", output)
|
||||
|
||||
// Verify check constraints are present
|
||||
if !strings.Contains(output, "-- Check constraints for schema: public") {
|
||||
t.Errorf("Output missing check constraints header")
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
|
||||
t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
|
||||
t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
|
||||
t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAllConstraintTypes(t *testing.T) {
|
||||
// Create a comprehensive test with all constraint types
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create orders table
|
||||
ordersTable := models.InitTable("orders", "public")
|
||||
|
||||
// Add columns
|
||||
idCol := models.InitColumn("id", "orders", "public")
|
||||
idCol.Type = "integer"
|
||||
idCol.IsPrimaryKey = true
|
||||
ordersTable.Columns["id"] = idCol
|
||||
|
||||
userIDCol := models.InitColumn("user_id", "orders", "public")
|
||||
userIDCol.Type = "integer"
|
||||
userIDCol.NotNull = true
|
||||
ordersTable.Columns["user_id"] = userIDCol
|
||||
|
||||
orderNumberCol := models.InitColumn("order_number", "orders", "public")
|
||||
orderNumberCol.Type = "varchar(50)"
|
||||
orderNumberCol.NotNull = true
|
||||
ordersTable.Columns["order_number"] = orderNumberCol
|
||||
|
||||
totalCol := models.InitColumn("total", "orders", "public")
|
||||
totalCol.Type = "numeric(10,2)"
|
||||
ordersTable.Columns["total"] = totalCol
|
||||
|
||||
statusCol := models.InitColumn("status", "orders", "public")
|
||||
statusCol.Type = "varchar(20)"
|
||||
ordersTable.Columns["status"] = statusCol
|
||||
|
||||
// Add primary key constraint
|
||||
pkConstraint := &models.Constraint{
|
||||
Name: "pk_orders",
|
||||
Type: models.PrimaryKeyConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"id"},
|
||||
}
|
||||
ordersTable.Constraints["pk_orders"] = pkConstraint
|
||||
|
||||
// Add unique constraint
|
||||
uniqueConstraint := &models.Constraint{
|
||||
Name: "uq_order_number",
|
||||
Type: models.UniqueConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"order_number"},
|
||||
}
|
||||
ordersTable.Constraints["uq_order_number"] = uniqueConstraint
|
||||
|
||||
// Add check constraint
|
||||
checkConstraint := &models.Constraint{
|
||||
Name: "ck_total_positive",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Expression: "total > 0",
|
||||
}
|
||||
ordersTable.Constraints["ck_total_positive"] = checkConstraint
|
||||
|
||||
statusCheckConstraint := &models.Constraint{
|
||||
Name: "ck_status_valid",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Expression: "status IN ('pending', 'completed', 'cancelled')",
|
||||
}
|
||||
ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint
|
||||
|
||||
// Add foreign key constraint (referencing a users table)
|
||||
fkConstraint := &models.Constraint{
|
||||
Name: "fk_orders_user",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"user_id"},
|
||||
ReferencedSchema: "public",
|
||||
ReferencedTable: "users",
|
||||
ReferencedColumns: []string{"id"},
|
||||
OnDelete: "CASCADE",
|
||||
OnUpdate: "CASCADE",
|
||||
}
|
||||
ordersTable.Constraints["fk_orders_user"] = fkConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, ordersTable)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with output to buffer
|
||||
var buf bytes.Buffer
|
||||
options := &writers.WriterOptions{}
|
||||
writer := NewWriter(options)
|
||||
writer.writer = &buf
|
||||
|
||||
// Write the database
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Print output for debugging
|
||||
t.Logf("Generated SQL:\n%s", output)
|
||||
|
||||
// Verify all constraint types are present
|
||||
expectedConstraints := map[string]string{
|
||||
"Primary Key": "PRIMARY KEY",
|
||||
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
|
||||
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
|
||||
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
|
||||
"Foreign Key": "FOREIGN KEY",
|
||||
}
|
||||
|
||||
for name, expected := range expectedConstraints {
|
||||
if !strings.Contains(output, expected) {
|
||||
t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify section headers
|
||||
sections := []string{
|
||||
"-- Primary keys for schema: public",
|
||||
"-- Unique constraints for schema: public",
|
||||
"-- Check constraints for schema: public",
|
||||
"-- Foreign keys for schema: public",
|
||||
}
|
||||
|
||||
for _, section := range sections {
|
||||
if !strings.Contains(output, section) {
|
||||
t.Errorf("Output missing section header: %s", section)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTable(t *testing.T) {
|
||||
// Create a single table
|
||||
table := models.InitTable("products", "public")
|
||||
|
||||
Reference in New Issue
Block a user