From 3c20c3c5d916c4f0e4f66e3eb695f64ce6f124b0 Mon Sep 17 00:00:00 2001 From: Hein Date: Sat, 31 Jan 2026 20:42:19 +0200 Subject: [PATCH] =?UTF-8?q?feat(writer):=20=F0=9F=8E=89=20Add=20support=20?= =?UTF-8?q?for=20check=20constraints=20in=20schema=20generation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement check constraints in the schema writer. * Generate SQL statements to add check constraints if they do not exist. * Add tests to verify correct generation of check constraints. --- pkg/readers/dbml/reader.go | 20 +-- pkg/readers/dbml/reader_test.go | 70 ++++++++++ pkg/writers/pgsql/writer.go | 150 +++++++++++++++++++++ pkg/writers/pgsql/writer_test.go | 220 +++++++++++++++++++++++++++++++ 4 files changed, 452 insertions(+), 8 deletions(-) diff --git a/pkg/readers/dbml/reader.go b/pkg/readers/dbml/reader.go index 3c6fce5..b005195 100644 --- a/pkg/readers/dbml/reader.go +++ b/pkg/readers/dbml/reader.go @@ -604,7 +604,7 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column } else if attr == "unique" { // Create a unique constraint uniqueConstraint := models.InitConstraint( - fmt.Sprintf("uq_%s", columnName), + fmt.Sprintf("uq_%s_%s", tableName, columnName), models.UniqueConstraint, ) uniqueConstraint.Schema = schemaName @@ -652,8 +652,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column constraint.Table = tableName constraint.Columns = []string{columnName} } - // Generate short constraint name based on the column - constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0]) + // Generate constraint name based on table and columns + constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_")) } } } @@ -737,7 +737,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index { // Generate name if not provided if index.Name == "" { - index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_")) + prefix := "idx" + if index.Unique { + prefix = "uidx" + } + index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_")) } return index @@ -797,10 +801,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint { return nil } - // Generate short constraint name based on the source column - constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable) - if len(fromColumns) > 0 { - constraintName = fmt.Sprintf("fk_%s", fromColumns[0]) + // Generate constraint name based on table and columns + constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_")) + if len(fromColumns) == 0 { + constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable) } constraint := models.InitConstraint( diff --git a/pkg/readers/dbml/reader_test.go b/pkg/readers/dbml/reader_test.go index 21edf0f..8a887e7 100644 --- a/pkg/readers/dbml/reader_test.go +++ b/pkg/readers/dbml/reader_test.go @@ -777,6 +777,76 @@ func TestParseFilePrefix(t *testing.T) { } } +func TestConstraintNaming(t *testing.T) { + // Test that constraints are named with proper prefixes + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + // Find users table + var usersTable *models.Table + var postsTable *models.Table + for _, schema := range db.Schemas { + for _, table := range schema.Tables { + if table.Name == "users" { + usersTable = table + } else if table.Name == "posts" { + postsTable = table + } + } + } + + if usersTable == nil { + t.Fatal("Users table not found") + } + if postsTable == nil { + t.Fatal("Posts table not found") + } + + // Test unique constraint naming: uq_table_column + if _, exists := usersTable.Constraints["uq_users_email"]; !exists { + t.Error("Expected unique constraint 'uq_users_email' not found") + t.Logf("Available constraints: %v", getKeys(usersTable.Constraints)) + } + + if _, exists := postsTable.Constraints["uq_posts_slug"]; !exists { + t.Error("Expected unique constraint 'uq_posts_slug' not found") + t.Logf("Available constraints: %v", getKeys(postsTable.Constraints)) + } + + // Test foreign key naming: fk_table_column + if _, exists := postsTable.Constraints["fk_posts_user_id"]; !exists { + t.Error("Expected foreign key 'fk_posts_user_id' not found") + t.Logf("Available constraints: %v", getKeys(postsTable.Constraints)) + } + + // Test unique index naming: uidx_table_columns + if _, exists := postsTable.Indexes["uidx_posts_slug"]; !exists { + t.Error("Expected unique index 'uidx_posts_slug' not found") + t.Logf("Available indexes: %v", getKeys(postsTable.Indexes)) + } + + // Test regular index naming: idx_table_columns + if _, exists := postsTable.Indexes["idx_posts_user_id_published"]; !exists { + t.Error("Expected index 'idx_posts_user_id_published' not found") + t.Logf("Available indexes: %v", getKeys(postsTable.Indexes)) + } +} + +func getKeys[V any](m map[string]V) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + func TestHasCommentedRefs(t *testing.T) { // Test with the actual multifile test fixtures tests := []struct { diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index f527083..d143007 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -320,6 +320,31 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro } } + // Phase 5.7: Check constraints + for _, table := range schema.Tables { + for _, constraint := range table.Constraints { + if constraint.Type != models.CheckConstraint { + continue + } + + // Wrap in DO block to check for existing constraint + stmt := fmt.Sprintf("DO $$\nBEGIN\n"+ + " IF NOT EXISTS (\n"+ + " SELECT 1 FROM information_schema.table_constraints\n"+ + " WHERE table_schema = '%s'\n"+ + " AND table_name = '%s'\n"+ + " AND constraint_name = '%s'\n"+ + " ) THEN\n"+ + " ALTER TABLE %s.%s ADD CONSTRAINT %s CHECK (%s);\n"+ + " END IF;\n"+ + "END;\n$$", + schema.Name, table.Name, constraint.Name, + schema.SQLName(), table.SQLName(), constraint.Name, + constraint.Expression) + statements = append(statements, stmt) + } + } + // Phase 6: Foreign keys for _, table := range schema.Tables { for _, constraint := range table.Constraints { @@ -572,6 +597,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error { return err } + // Phase 5.7: Create check constraints (priority 190) + if err := w.writeCheckConstraints(schema); err != nil { + return err + } + // Phase 6: Create foreign key constraints (priority 195) if err := w.writeForeignKeys(schema); err != nil { return err @@ -944,6 +974,48 @@ func (w *Writer) writeUniqueConstraints(schema *models.Schema) error { return nil } +// writeCheckConstraints generates ALTER TABLE statements for check constraints +func (w *Writer) writeCheckConstraints(schema *models.Schema) error { + fmt.Fprintf(w.writer, "-- Check constraints for schema: %s\n", schema.Name) + + for _, table := range schema.Tables { + // Sort constraints by name for consistent output + constraintNames := make([]string, 0, len(table.Constraints)) + for name, constraint := range table.Constraints { + if constraint.Type == models.CheckConstraint { + constraintNames = append(constraintNames, name) + } + } + sort.Strings(constraintNames) + + for _, name := range constraintNames { + constraint := table.Constraints[name] + + // Skip if expression is empty + if constraint.Expression == "" { + continue + } + + // Wrap in DO block to check for existing constraint + fmt.Fprintf(w.writer, "DO $$\n") + fmt.Fprintf(w.writer, "BEGIN\n") + fmt.Fprintf(w.writer, " IF NOT EXISTS (\n") + fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n") + fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) + fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name) + fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", constraint.Name) + fmt.Fprintf(w.writer, " ) THEN\n") + fmt.Fprintf(w.writer, " ALTER TABLE %s.%s ADD CONSTRAINT %s CHECK (%s);\n", + schema.SQLName(), table.SQLName(), constraint.Name, constraint.Expression) + fmt.Fprintf(w.writer, " END IF;\n") + fmt.Fprintf(w.writer, "END;\n") + fmt.Fprintf(w.writer, "$$;\n\n") + } + } + + return nil +} + // writeForeignKeys generates ALTER TABLE statements for foreign keys func (w *Writer) writeForeignKeys(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name) @@ -1040,6 +1112,84 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error { fmt.Fprintf(w.writer, " END IF;\n") fmt.Fprintf(w.writer, "END;\n$$;\n\n") } + + // Also process any foreign key constraints that don't have a relationship + processedConstraints := make(map[string]bool) + for _, rel := range table.Relationships { + fkName := rel.ForeignKey + if fkName == "" { + fkName = rel.Name + } + if fkName != "" { + processedConstraints[fkName] = true + } + } + + // Find unprocessed foreign key constraints + constraintNames := make([]string, 0) + for name, constraint := range table.Constraints { + if constraint.Type == models.ForeignKeyConstraint && !processedConstraints[name] { + constraintNames = append(constraintNames, name) + } + } + sort.Strings(constraintNames) + + for _, name := range constraintNames { + constraint := table.Constraints[name] + + // Build column lists + sourceColumns := make([]string, 0, len(constraint.Columns)) + for _, colName := range constraint.Columns { + if col, ok := table.Columns[colName]; ok { + sourceColumns = append(sourceColumns, col.SQLName()) + } else { + sourceColumns = append(sourceColumns, colName) + } + } + + targetColumns := make([]string, 0, len(constraint.ReferencedColumns)) + for _, colName := range constraint.ReferencedColumns { + targetColumns = append(targetColumns, strings.ToLower(colName)) + } + + if len(sourceColumns) == 0 || len(targetColumns) == 0 { + continue + } + + onDelete := "NO ACTION" + if constraint.OnDelete != "" { + onDelete = strings.ToUpper(constraint.OnDelete) + } + + onUpdate := "NO ACTION" + if constraint.OnUpdate != "" { + onUpdate = strings.ToUpper(constraint.OnUpdate) + } + + refSchema := constraint.ReferencedSchema + if refSchema == "" { + refSchema = schema.Name + } + refTable := constraint.ReferencedTable + + // Use DO block to check if constraint exists before adding + fmt.Fprintf(w.writer, "DO $$\nBEGIN\n") + fmt.Fprintf(w.writer, " IF NOT EXISTS (\n") + fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n") + fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) + fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name) + fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", constraint.Name) + fmt.Fprintf(w.writer, " ) THEN\n") + fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) + fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", constraint.Name) + fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", ")) + fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n", + refSchema, refTable, strings.Join(targetColumns, ", ")) + fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete) + fmt.Fprintf(w.writer, " ON UPDATE %s;\n", onUpdate) + fmt.Fprintf(w.writer, " END IF;\n") + fmt.Fprintf(w.writer, "END;\n$$;\n\n") + } } return nil diff --git a/pkg/writers/pgsql/writer_test.go b/pkg/writers/pgsql/writer_test.go index eb3f9b7..06adcdc 100644 --- a/pkg/writers/pgsql/writer_test.go +++ b/pkg/writers/pgsql/writer_test.go @@ -234,6 +234,226 @@ func TestWriteUniqueConstraints(t *testing.T) { } } +func TestWriteCheckConstraints(t *testing.T) { + // Create a test database with check constraints + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + // Create table with check constraints + table := models.InitTable("products", "public") + + // Add columns + priceCol := models.InitColumn("price", "products", "public") + priceCol.Type = "numeric(10,2)" + table.Columns["price"] = priceCol + + statusCol := models.InitColumn("status", "products", "public") + statusCol.Type = "varchar(20)" + table.Columns["status"] = statusCol + + quantityCol := models.InitColumn("quantity", "products", "public") + quantityCol.Type = "integer" + table.Columns["quantity"] = quantityCol + + // Add check constraints + priceConstraint := &models.Constraint{ + Name: "ck_price_positive", + Type: models.CheckConstraint, + Schema: "public", + Table: "products", + Expression: "price >= 0", + } + table.Constraints["ck_price_positive"] = priceConstraint + + statusConstraint := &models.Constraint{ + Name: "ck_status_valid", + Type: models.CheckConstraint, + Schema: "public", + Table: "products", + Expression: "status IN ('active', 'inactive', 'discontinued')", + } + table.Constraints["ck_status_valid"] = statusConstraint + + quantityConstraint := &models.Constraint{ + Name: "ck_quantity_nonnegative", + Type: models.CheckConstraint, + Schema: "public", + Table: "products", + Expression: "quantity >= 0", + } + table.Constraints["ck_quantity_nonnegative"] = quantityConstraint + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + // Create writer with output to buffer + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + // Write the database + err := writer.WriteDatabase(db) + if err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + + // Print output for debugging + t.Logf("Generated SQL:\n%s", output) + + // Verify check constraints are present + if !strings.Contains(output, "-- Check constraints for schema: public") { + t.Errorf("Output missing check constraints header") + } + if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") { + t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output) + } + if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") { + t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output) + } + if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") { + t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output) + } +} + +func TestWriteAllConstraintTypes(t *testing.T) { + // Create a comprehensive test with all constraint types + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + // Create orders table + ordersTable := models.InitTable("orders", "public") + + // Add columns + idCol := models.InitColumn("id", "orders", "public") + idCol.Type = "integer" + idCol.IsPrimaryKey = true + ordersTable.Columns["id"] = idCol + + userIDCol := models.InitColumn("user_id", "orders", "public") + userIDCol.Type = "integer" + userIDCol.NotNull = true + ordersTable.Columns["user_id"] = userIDCol + + orderNumberCol := models.InitColumn("order_number", "orders", "public") + orderNumberCol.Type = "varchar(50)" + orderNumberCol.NotNull = true + ordersTable.Columns["order_number"] = orderNumberCol + + totalCol := models.InitColumn("total", "orders", "public") + totalCol.Type = "numeric(10,2)" + ordersTable.Columns["total"] = totalCol + + statusCol := models.InitColumn("status", "orders", "public") + statusCol.Type = "varchar(20)" + ordersTable.Columns["status"] = statusCol + + // Add primary key constraint + pkConstraint := &models.Constraint{ + Name: "pk_orders", + Type: models.PrimaryKeyConstraint, + Schema: "public", + Table: "orders", + Columns: []string{"id"}, + } + ordersTable.Constraints["pk_orders"] = pkConstraint + + // Add unique constraint + uniqueConstraint := &models.Constraint{ + Name: "uq_order_number", + Type: models.UniqueConstraint, + Schema: "public", + Table: "orders", + Columns: []string{"order_number"}, + } + ordersTable.Constraints["uq_order_number"] = uniqueConstraint + + // Add check constraint + checkConstraint := &models.Constraint{ + Name: "ck_total_positive", + Type: models.CheckConstraint, + Schema: "public", + Table: "orders", + Expression: "total > 0", + } + ordersTable.Constraints["ck_total_positive"] = checkConstraint + + statusCheckConstraint := &models.Constraint{ + Name: "ck_status_valid", + Type: models.CheckConstraint, + Schema: "public", + Table: "orders", + Expression: "status IN ('pending', 'completed', 'cancelled')", + } + ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint + + // Add foreign key constraint (referencing a users table) + fkConstraint := &models.Constraint{ + Name: "fk_orders_user", + Type: models.ForeignKeyConstraint, + Schema: "public", + Table: "orders", + Columns: []string{"user_id"}, + ReferencedSchema: "public", + ReferencedTable: "users", + ReferencedColumns: []string{"id"}, + OnDelete: "CASCADE", + OnUpdate: "CASCADE", + } + ordersTable.Constraints["fk_orders_user"] = fkConstraint + + schema.Tables = append(schema.Tables, ordersTable) + db.Schemas = append(db.Schemas, schema) + + // Create writer with output to buffer + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + // Write the database + err := writer.WriteDatabase(db) + if err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + + // Print output for debugging + t.Logf("Generated SQL:\n%s", output) + + // Verify all constraint types are present + expectedConstraints := map[string]string{ + "Primary Key": "PRIMARY KEY", + "Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)", + "Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)", + "Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))", + "Foreign Key": "FOREIGN KEY", + } + + for name, expected := range expectedConstraints { + if !strings.Contains(output, expected) { + t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output) + } + } + + // Verify section headers + sections := []string{ + "-- Primary keys for schema: public", + "-- Unique constraints for schema: public", + "-- Check constraints for schema: public", + "-- Foreign keys for schema: public", + } + + for _, section := range sections { + if !strings.Contains(output, section) { + t.Errorf("Output missing section header: %s", section) + } + } +} + func TestWriteTable(t *testing.T) { // Create a single table table := models.InitTable("products", "public")