feat(writer): 🎉 Add support for check constraints in schema generation
All checks were successful
CI / Test (1.24) (push) Successful in -26m17s
CI / Test (1.25) (push) Successful in -26m14s
CI / Build (push) Successful in -26m41s
CI / Lint (push) Successful in -26m32s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m13s
All checks were successful
CI / Test (1.24) (push) Successful in -26m17s
CI / Test (1.25) (push) Successful in -26m14s
CI / Build (push) Successful in -26m41s
CI / Lint (push) Successful in -26m32s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m13s
* Implement check constraints in the schema writer. * Generate SQL statements to add check constraints if they do not exist. * Add tests to verify correct generation of check constraints.
This commit is contained in:
@@ -320,6 +320,31 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5.7: Check constraints
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.CheckConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
// Wrap in DO block to check for existing constraint
|
||||
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
|
||||
" IF NOT EXISTS (\n"+
|
||||
" SELECT 1 FROM information_schema.table_constraints\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND constraint_name = '%s'\n"+
|
||||
" ) THEN\n"+
|
||||
" ALTER TABLE %s.%s ADD CONSTRAINT %s CHECK (%s);\n"+
|
||||
" END IF;\n"+
|
||||
"END;\n$$",
|
||||
schema.Name, table.Name, constraint.Name,
|
||||
schema.SQLName(), table.SQLName(), constraint.Name,
|
||||
constraint.Expression)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 6: Foreign keys
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
@@ -572,6 +597,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 5.7: Create check constraints (priority 190)
|
||||
if err := w.writeCheckConstraints(schema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 6: Create foreign key constraints (priority 195)
|
||||
if err := w.writeForeignKeys(schema); err != nil {
|
||||
return err
|
||||
@@ -944,6 +974,48 @@ func (w *Writer) writeUniqueConstraints(schema *models.Schema) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCheckConstraints generates ALTER TABLE statements for check constraints
|
||||
func (w *Writer) writeCheckConstraints(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Check constraints for schema: %s\n", schema.Name)
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
// Sort constraints by name for consistent output
|
||||
constraintNames := make([]string, 0, len(table.Constraints))
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.CheckConstraint {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Skip if expression is empty
|
||||
if constraint.Expression == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Wrap in DO block to check for existing constraint
|
||||
fmt.Fprintf(w.writer, "DO $$\n")
|
||||
fmt.Fprintf(w.writer, "BEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", constraint.Name)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s ADD CONSTRAINT %s CHECK (%s);\n",
|
||||
schema.SQLName(), table.SQLName(), constraint.Name, constraint.Expression)
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n")
|
||||
fmt.Fprintf(w.writer, "$$;\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeForeignKeys generates ALTER TABLE statements for foreign keys
|
||||
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
|
||||
@@ -1040,6 +1112,84 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||
}
|
||||
|
||||
// Also process any foreign key constraints that don't have a relationship
|
||||
processedConstraints := make(map[string]bool)
|
||||
for _, rel := range table.Relationships {
|
||||
fkName := rel.ForeignKey
|
||||
if fkName == "" {
|
||||
fkName = rel.Name
|
||||
}
|
||||
if fkName != "" {
|
||||
processedConstraints[fkName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Find unprocessed foreign key constraints
|
||||
constraintNames := make([]string, 0)
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint && !processedConstraints[name] {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Build column lists
|
||||
sourceColumns := make([]string, 0, len(constraint.Columns))
|
||||
for _, colName := range constraint.Columns {
|
||||
if col, ok := table.Columns[colName]; ok {
|
||||
sourceColumns = append(sourceColumns, col.SQLName())
|
||||
} else {
|
||||
sourceColumns = append(sourceColumns, colName)
|
||||
}
|
||||
}
|
||||
|
||||
targetColumns := make([]string, 0, len(constraint.ReferencedColumns))
|
||||
for _, colName := range constraint.ReferencedColumns {
|
||||
targetColumns = append(targetColumns, strings.ToLower(colName))
|
||||
}
|
||||
|
||||
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
onDelete := "NO ACTION"
|
||||
if constraint.OnDelete != "" {
|
||||
onDelete = strings.ToUpper(constraint.OnDelete)
|
||||
}
|
||||
|
||||
onUpdate := "NO ACTION"
|
||||
if constraint.OnUpdate != "" {
|
||||
onUpdate = strings.ToUpper(constraint.OnUpdate)
|
||||
}
|
||||
|
||||
refSchema := constraint.ReferencedSchema
|
||||
if refSchema == "" {
|
||||
refSchema = schema.Name
|
||||
}
|
||||
refTable := constraint.ReferencedTable
|
||||
|
||||
// Use DO block to check if constraint exists before adding
|
||||
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", constraint.Name)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", constraint.Name)
|
||||
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
|
||||
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
|
||||
refSchema, refTable, strings.Join(targetColumns, ", "))
|
||||
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
|
||||
fmt.Fprintf(w.writer, " ON UPDATE %s;\n", onUpdate)
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -234,6 +234,226 @@ func TestWriteUniqueConstraints(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteCheckConstraints(t *testing.T) {
|
||||
// Create a test database with check constraints
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create table with check constraints
|
||||
table := models.InitTable("products", "public")
|
||||
|
||||
// Add columns
|
||||
priceCol := models.InitColumn("price", "products", "public")
|
||||
priceCol.Type = "numeric(10,2)"
|
||||
table.Columns["price"] = priceCol
|
||||
|
||||
statusCol := models.InitColumn("status", "products", "public")
|
||||
statusCol.Type = "varchar(20)"
|
||||
table.Columns["status"] = statusCol
|
||||
|
||||
quantityCol := models.InitColumn("quantity", "products", "public")
|
||||
quantityCol.Type = "integer"
|
||||
table.Columns["quantity"] = quantityCol
|
||||
|
||||
// Add check constraints
|
||||
priceConstraint := &models.Constraint{
|
||||
Name: "ck_price_positive",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "price >= 0",
|
||||
}
|
||||
table.Constraints["ck_price_positive"] = priceConstraint
|
||||
|
||||
statusConstraint := &models.Constraint{
|
||||
Name: "ck_status_valid",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "status IN ('active', 'inactive', 'discontinued')",
|
||||
}
|
||||
table.Constraints["ck_status_valid"] = statusConstraint
|
||||
|
||||
quantityConstraint := &models.Constraint{
|
||||
Name: "ck_quantity_nonnegative",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "quantity >= 0",
|
||||
}
|
||||
table.Constraints["ck_quantity_nonnegative"] = quantityConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with output to buffer
|
||||
var buf bytes.Buffer
|
||||
options := &writers.WriterOptions{}
|
||||
writer := NewWriter(options)
|
||||
writer.writer = &buf
|
||||
|
||||
// Write the database
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Print output for debugging
|
||||
t.Logf("Generated SQL:\n%s", output)
|
||||
|
||||
// Verify check constraints are present
|
||||
if !strings.Contains(output, "-- Check constraints for schema: public") {
|
||||
t.Errorf("Output missing check constraints header")
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
|
||||
t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
|
||||
t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
|
||||
t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAllConstraintTypes(t *testing.T) {
|
||||
// Create a comprehensive test with all constraint types
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create orders table
|
||||
ordersTable := models.InitTable("orders", "public")
|
||||
|
||||
// Add columns
|
||||
idCol := models.InitColumn("id", "orders", "public")
|
||||
idCol.Type = "integer"
|
||||
idCol.IsPrimaryKey = true
|
||||
ordersTable.Columns["id"] = idCol
|
||||
|
||||
userIDCol := models.InitColumn("user_id", "orders", "public")
|
||||
userIDCol.Type = "integer"
|
||||
userIDCol.NotNull = true
|
||||
ordersTable.Columns["user_id"] = userIDCol
|
||||
|
||||
orderNumberCol := models.InitColumn("order_number", "orders", "public")
|
||||
orderNumberCol.Type = "varchar(50)"
|
||||
orderNumberCol.NotNull = true
|
||||
ordersTable.Columns["order_number"] = orderNumberCol
|
||||
|
||||
totalCol := models.InitColumn("total", "orders", "public")
|
||||
totalCol.Type = "numeric(10,2)"
|
||||
ordersTable.Columns["total"] = totalCol
|
||||
|
||||
statusCol := models.InitColumn("status", "orders", "public")
|
||||
statusCol.Type = "varchar(20)"
|
||||
ordersTable.Columns["status"] = statusCol
|
||||
|
||||
// Add primary key constraint
|
||||
pkConstraint := &models.Constraint{
|
||||
Name: "pk_orders",
|
||||
Type: models.PrimaryKeyConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"id"},
|
||||
}
|
||||
ordersTable.Constraints["pk_orders"] = pkConstraint
|
||||
|
||||
// Add unique constraint
|
||||
uniqueConstraint := &models.Constraint{
|
||||
Name: "uq_order_number",
|
||||
Type: models.UniqueConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"order_number"},
|
||||
}
|
||||
ordersTable.Constraints["uq_order_number"] = uniqueConstraint
|
||||
|
||||
// Add check constraint
|
||||
checkConstraint := &models.Constraint{
|
||||
Name: "ck_total_positive",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Expression: "total > 0",
|
||||
}
|
||||
ordersTable.Constraints["ck_total_positive"] = checkConstraint
|
||||
|
||||
statusCheckConstraint := &models.Constraint{
|
||||
Name: "ck_status_valid",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Expression: "status IN ('pending', 'completed', 'cancelled')",
|
||||
}
|
||||
ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint
|
||||
|
||||
// Add foreign key constraint (referencing a users table)
|
||||
fkConstraint := &models.Constraint{
|
||||
Name: "fk_orders_user",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"user_id"},
|
||||
ReferencedSchema: "public",
|
||||
ReferencedTable: "users",
|
||||
ReferencedColumns: []string{"id"},
|
||||
OnDelete: "CASCADE",
|
||||
OnUpdate: "CASCADE",
|
||||
}
|
||||
ordersTable.Constraints["fk_orders_user"] = fkConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, ordersTable)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with output to buffer
|
||||
var buf bytes.Buffer
|
||||
options := &writers.WriterOptions{}
|
||||
writer := NewWriter(options)
|
||||
writer.writer = &buf
|
||||
|
||||
// Write the database
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Print output for debugging
|
||||
t.Logf("Generated SQL:\n%s", output)
|
||||
|
||||
// Verify all constraint types are present
|
||||
expectedConstraints := map[string]string{
|
||||
"Primary Key": "PRIMARY KEY",
|
||||
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
|
||||
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
|
||||
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
|
||||
"Foreign Key": "FOREIGN KEY",
|
||||
}
|
||||
|
||||
for name, expected := range expectedConstraints {
|
||||
if !strings.Contains(output, expected) {
|
||||
t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify section headers
|
||||
sections := []string{
|
||||
"-- Primary keys for schema: public",
|
||||
"-- Unique constraints for schema: public",
|
||||
"-- Check constraints for schema: public",
|
||||
"-- Foreign keys for schema: public",
|
||||
}
|
||||
|
||||
for _, section := range sections {
|
||||
if !strings.Contains(output, section) {
|
||||
t.Errorf("Output missing section header: %s", section)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTable(t *testing.T) {
|
||||
// Create a single table
|
||||
table := models.InitTable("products", "public")
|
||||
|
||||
Reference in New Issue
Block a user