All checks were successful
CI / Test (1.24) (push) Successful in -26m17s
CI / Test (1.25) (push) Successful in -26m14s
CI / Build (push) Successful in -26m41s
CI / Lint (push) Successful in -26m32s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m13s
* Implement check constraints in the schema writer. * Generate SQL statements to add check constraints if they do not exist. * Add tests to verify correct generation of check constraints.
858 lines
25 KiB
Go
858 lines
25 KiB
Go
package pgsql
|
|
|
|
import (
|
|
"bytes"
|
|
"strings"
|
|
"testing"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
|
)
|
|
|
|
func TestWriteDatabase(t *testing.T) {
|
|
// Create a test database
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create a test table
|
|
table := models.InitTable("users", "public")
|
|
table.Description = "User accounts table"
|
|
|
|
// Add columns
|
|
idCol := models.InitColumn("id", "users", "public")
|
|
idCol.Type = "integer"
|
|
idCol.Description = "Primary key"
|
|
idCol.Default = "nextval('public.identity_users_id'::regclass)"
|
|
table.Columns["id"] = idCol
|
|
|
|
nameCol := models.InitColumn("name", "users", "public")
|
|
nameCol.Type = "text"
|
|
nameCol.Description = "User name"
|
|
table.Columns["name"] = nameCol
|
|
|
|
emailCol := models.InitColumn("email", "users", "public")
|
|
emailCol.Type = "text"
|
|
emailCol.Description = "Email address"
|
|
table.Columns["email"] = emailCol
|
|
|
|
// Add primary key constraint
|
|
pkConstraint := &models.Constraint{
|
|
Name: "pk_users",
|
|
Type: models.PrimaryKeyConstraint,
|
|
Columns: []string{"id"},
|
|
}
|
|
table.Constraints["pk_users"] = pkConstraint
|
|
|
|
// Add unique index
|
|
uniqueEmailIndex := &models.Index{
|
|
Name: "uidx_users_email",
|
|
Unique: true,
|
|
Columns: []string{"email"},
|
|
}
|
|
table.Indexes["uidx_users_email"] = uniqueEmailIndex
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify output contains expected elements
|
|
expectedStrings := []string{
|
|
"CREATE TABLE",
|
|
"PRIMARY KEY",
|
|
"UNIQUE INDEX",
|
|
"COMMENT ON TABLE",
|
|
"COMMENT ON COLUMN",
|
|
}
|
|
|
|
for _, expected := range expectedStrings {
|
|
if !strings.Contains(output, expected) {
|
|
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWriteForeignKeys(t *testing.T) {
|
|
// Create a test database with two related tables
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create parent table (users)
|
|
usersTable := models.InitTable("users", "public")
|
|
idCol := models.InitColumn("id", "users", "public")
|
|
idCol.Type = "integer"
|
|
usersTable.Columns["id"] = idCol
|
|
|
|
// Create child table (posts)
|
|
postsTable := models.InitTable("posts", "public")
|
|
postIdCol := models.InitColumn("id", "posts", "public")
|
|
postIdCol.Type = "integer"
|
|
postsTable.Columns["id"] = postIdCol
|
|
|
|
userIdCol := models.InitColumn("user_id", "posts", "public")
|
|
userIdCol.Type = "integer"
|
|
postsTable.Columns["user_id"] = userIdCol
|
|
|
|
// Add foreign key constraint
|
|
fkConstraint := &models.Constraint{
|
|
Name: "fk_posts_users",
|
|
Type: models.ForeignKeyConstraint,
|
|
Columns: []string{"user_id"},
|
|
ReferencedTable: "users",
|
|
ReferencedSchema: "public",
|
|
ReferencedColumns: []string{"id"},
|
|
OnDelete: "CASCADE",
|
|
OnUpdate: "CASCADE",
|
|
}
|
|
postsTable.Constraints["fk_posts_users"] = fkConstraint
|
|
|
|
// Add relationship
|
|
relationship := &models.Relationship{
|
|
Name: "fk_posts_users",
|
|
FromTable: "posts",
|
|
FromSchema: "public",
|
|
ToTable: "users",
|
|
ToSchema: "public",
|
|
ForeignKey: "fk_posts_users",
|
|
}
|
|
postsTable.Relationships["fk_posts_users"] = relationship
|
|
|
|
schema.Tables = append(schema.Tables, usersTable, postsTable)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify foreign key is present
|
|
if !strings.Contains(output, "FOREIGN KEY") {
|
|
t.Errorf("Output missing FOREIGN KEY statement\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "ON DELETE CASCADE") {
|
|
t.Errorf("Output missing ON DELETE CASCADE\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "ON UPDATE CASCADE") {
|
|
t.Errorf("Output missing ON UPDATE CASCADE\nFull output:\n%s", output)
|
|
}
|
|
}
|
|
|
|
func TestWriteUniqueConstraints(t *testing.T) {
|
|
// Create a test database with unique constraints
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create table with unique constraints
|
|
table := models.InitTable("users", "public")
|
|
|
|
// Add columns
|
|
emailCol := models.InitColumn("email", "users", "public")
|
|
emailCol.Type = "varchar(255)"
|
|
emailCol.NotNull = true
|
|
table.Columns["email"] = emailCol
|
|
|
|
guidCol := models.InitColumn("guid", "users", "public")
|
|
guidCol.Type = "uuid"
|
|
guidCol.NotNull = true
|
|
table.Columns["guid"] = guidCol
|
|
|
|
// Add unique constraints
|
|
emailConstraint := &models.Constraint{
|
|
Name: "uq_email",
|
|
Type: models.UniqueConstraint,
|
|
Schema: "public",
|
|
Table: "users",
|
|
Columns: []string{"email"},
|
|
}
|
|
table.Constraints["uq_email"] = emailConstraint
|
|
|
|
guidConstraint := &models.Constraint{
|
|
Name: "uq_guid",
|
|
Type: models.UniqueConstraint,
|
|
Schema: "public",
|
|
Table: "users",
|
|
Columns: []string{"guid"},
|
|
}
|
|
table.Constraints["uq_guid"] = guidConstraint
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify unique constraints are present
|
|
if !strings.Contains(output, "-- Unique constraints for schema: public") {
|
|
t.Errorf("Output missing unique constraints header")
|
|
}
|
|
if !strings.Contains(output, "ADD CONSTRAINT uq_email UNIQUE (email)") {
|
|
t.Errorf("Output missing uq_email unique constraint\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "ADD CONSTRAINT uq_guid UNIQUE (guid)") {
|
|
t.Errorf("Output missing uq_guid unique constraint\nFull output:\n%s", output)
|
|
}
|
|
}
|
|
|
|
func TestWriteCheckConstraints(t *testing.T) {
|
|
// Create a test database with check constraints
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create table with check constraints
|
|
table := models.InitTable("products", "public")
|
|
|
|
// Add columns
|
|
priceCol := models.InitColumn("price", "products", "public")
|
|
priceCol.Type = "numeric(10,2)"
|
|
table.Columns["price"] = priceCol
|
|
|
|
statusCol := models.InitColumn("status", "products", "public")
|
|
statusCol.Type = "varchar(20)"
|
|
table.Columns["status"] = statusCol
|
|
|
|
quantityCol := models.InitColumn("quantity", "products", "public")
|
|
quantityCol.Type = "integer"
|
|
table.Columns["quantity"] = quantityCol
|
|
|
|
// Add check constraints
|
|
priceConstraint := &models.Constraint{
|
|
Name: "ck_price_positive",
|
|
Type: models.CheckConstraint,
|
|
Schema: "public",
|
|
Table: "products",
|
|
Expression: "price >= 0",
|
|
}
|
|
table.Constraints["ck_price_positive"] = priceConstraint
|
|
|
|
statusConstraint := &models.Constraint{
|
|
Name: "ck_status_valid",
|
|
Type: models.CheckConstraint,
|
|
Schema: "public",
|
|
Table: "products",
|
|
Expression: "status IN ('active', 'inactive', 'discontinued')",
|
|
}
|
|
table.Constraints["ck_status_valid"] = statusConstraint
|
|
|
|
quantityConstraint := &models.Constraint{
|
|
Name: "ck_quantity_nonnegative",
|
|
Type: models.CheckConstraint,
|
|
Schema: "public",
|
|
Table: "products",
|
|
Expression: "quantity >= 0",
|
|
}
|
|
table.Constraints["ck_quantity_nonnegative"] = quantityConstraint
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify check constraints are present
|
|
if !strings.Contains(output, "-- Check constraints for schema: public") {
|
|
t.Errorf("Output missing check constraints header")
|
|
}
|
|
if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
|
|
t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
|
|
t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
|
|
t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
|
|
}
|
|
}
|
|
|
|
func TestWriteAllConstraintTypes(t *testing.T) {
|
|
// Create a comprehensive test with all constraint types
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create orders table
|
|
ordersTable := models.InitTable("orders", "public")
|
|
|
|
// Add columns
|
|
idCol := models.InitColumn("id", "orders", "public")
|
|
idCol.Type = "integer"
|
|
idCol.IsPrimaryKey = true
|
|
ordersTable.Columns["id"] = idCol
|
|
|
|
userIDCol := models.InitColumn("user_id", "orders", "public")
|
|
userIDCol.Type = "integer"
|
|
userIDCol.NotNull = true
|
|
ordersTable.Columns["user_id"] = userIDCol
|
|
|
|
orderNumberCol := models.InitColumn("order_number", "orders", "public")
|
|
orderNumberCol.Type = "varchar(50)"
|
|
orderNumberCol.NotNull = true
|
|
ordersTable.Columns["order_number"] = orderNumberCol
|
|
|
|
totalCol := models.InitColumn("total", "orders", "public")
|
|
totalCol.Type = "numeric(10,2)"
|
|
ordersTable.Columns["total"] = totalCol
|
|
|
|
statusCol := models.InitColumn("status", "orders", "public")
|
|
statusCol.Type = "varchar(20)"
|
|
ordersTable.Columns["status"] = statusCol
|
|
|
|
// Add primary key constraint
|
|
pkConstraint := &models.Constraint{
|
|
Name: "pk_orders",
|
|
Type: models.PrimaryKeyConstraint,
|
|
Schema: "public",
|
|
Table: "orders",
|
|
Columns: []string{"id"},
|
|
}
|
|
ordersTable.Constraints["pk_orders"] = pkConstraint
|
|
|
|
// Add unique constraint
|
|
uniqueConstraint := &models.Constraint{
|
|
Name: "uq_order_number",
|
|
Type: models.UniqueConstraint,
|
|
Schema: "public",
|
|
Table: "orders",
|
|
Columns: []string{"order_number"},
|
|
}
|
|
ordersTable.Constraints["uq_order_number"] = uniqueConstraint
|
|
|
|
// Add check constraint
|
|
checkConstraint := &models.Constraint{
|
|
Name: "ck_total_positive",
|
|
Type: models.CheckConstraint,
|
|
Schema: "public",
|
|
Table: "orders",
|
|
Expression: "total > 0",
|
|
}
|
|
ordersTable.Constraints["ck_total_positive"] = checkConstraint
|
|
|
|
statusCheckConstraint := &models.Constraint{
|
|
Name: "ck_status_valid",
|
|
Type: models.CheckConstraint,
|
|
Schema: "public",
|
|
Table: "orders",
|
|
Expression: "status IN ('pending', 'completed', 'cancelled')",
|
|
}
|
|
ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint
|
|
|
|
// Add foreign key constraint (referencing a users table)
|
|
fkConstraint := &models.Constraint{
|
|
Name: "fk_orders_user",
|
|
Type: models.ForeignKeyConstraint,
|
|
Schema: "public",
|
|
Table: "orders",
|
|
Columns: []string{"user_id"},
|
|
ReferencedSchema: "public",
|
|
ReferencedTable: "users",
|
|
ReferencedColumns: []string{"id"},
|
|
OnDelete: "CASCADE",
|
|
OnUpdate: "CASCADE",
|
|
}
|
|
ordersTable.Constraints["fk_orders_user"] = fkConstraint
|
|
|
|
schema.Tables = append(schema.Tables, ordersTable)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify all constraint types are present
|
|
expectedConstraints := map[string]string{
|
|
"Primary Key": "PRIMARY KEY",
|
|
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
|
|
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
|
|
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
|
|
"Foreign Key": "FOREIGN KEY",
|
|
}
|
|
|
|
for name, expected := range expectedConstraints {
|
|
if !strings.Contains(output, expected) {
|
|
t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
|
|
}
|
|
}
|
|
|
|
// Verify section headers
|
|
sections := []string{
|
|
"-- Primary keys for schema: public",
|
|
"-- Unique constraints for schema: public",
|
|
"-- Check constraints for schema: public",
|
|
"-- Foreign keys for schema: public",
|
|
}
|
|
|
|
for _, section := range sections {
|
|
if !strings.Contains(output, section) {
|
|
t.Errorf("Output missing section header: %s", section)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWriteTable(t *testing.T) {
|
|
// Create a single table
|
|
table := models.InitTable("products", "public")
|
|
|
|
idCol := models.InitColumn("id", "products", "public")
|
|
idCol.Type = "integer"
|
|
table.Columns["id"] = idCol
|
|
|
|
nameCol := models.InitColumn("name", "products", "public")
|
|
nameCol.Type = "text"
|
|
table.Columns["name"] = nameCol
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the table
|
|
err := writer.WriteTable(table)
|
|
if err != nil {
|
|
t.Fatalf("WriteTable failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Verify output contains table creation
|
|
if !strings.Contains(output, "CREATE TABLE") {
|
|
t.Error("Output missing CREATE TABLE statement")
|
|
}
|
|
if !strings.Contains(output, "products") {
|
|
t.Error("Output missing table name 'products'")
|
|
}
|
|
}
|
|
|
|
func TestEscapeQuote(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"simple text", "simple text"},
|
|
{"text with 'quote'", "text with ''quote''"},
|
|
{"multiple 'quotes' here", "multiple ''quotes'' here"},
|
|
{"", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := escapeQuote(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("escapeQuote(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIsIntegerType(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected bool
|
|
}{
|
|
{"integer", true},
|
|
{"INTEGER", true},
|
|
{"bigint", true},
|
|
{"smallint", true},
|
|
{"serial", true},
|
|
{"bigserial", true},
|
|
{"text", false},
|
|
{"varchar", false},
|
|
{"uuid", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := isIntegerType(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("isIntegerType(%q) = %v, want %v", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTypeConversion(t *testing.T) {
|
|
// Test that invalid Go types are converted to valid PostgreSQL types
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create a test table with Go types instead of SQL types
|
|
table := models.InitTable("test_types", "public")
|
|
|
|
// Add columns with Go types (invalid for PostgreSQL)
|
|
stringCol := models.InitColumn("name", "test_types", "public")
|
|
stringCol.Type = "string" // Should be converted to "text"
|
|
table.Columns["name"] = stringCol
|
|
|
|
int64Col := models.InitColumn("big_id", "test_types", "public")
|
|
int64Col.Type = "int64" // Should be converted to "bigint"
|
|
table.Columns["big_id"] = int64Col
|
|
|
|
int16Col := models.InitColumn("small_id", "test_types", "public")
|
|
int16Col.Type = "int16" // Should be converted to "smallint"
|
|
table.Columns["small_id"] = int16Col
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify that Go types were converted to PostgreSQL types
|
|
if strings.Contains(output, "string") {
|
|
t.Errorf("Output contains 'string' type - should be converted to 'text'\nFull output:\n%s", output)
|
|
}
|
|
if strings.Contains(output, "int64") {
|
|
t.Errorf("Output contains 'int64' type - should be converted to 'bigint'\nFull output:\n%s", output)
|
|
}
|
|
if strings.Contains(output, "int16") {
|
|
t.Errorf("Output contains 'int16' type - should be converted to 'smallint'\nFull output:\n%s", output)
|
|
}
|
|
|
|
// Verify correct PostgreSQL types are present
|
|
if !strings.Contains(output, "text") {
|
|
t.Errorf("Output missing 'text' type (converted from 'string')\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "bigint") {
|
|
t.Errorf("Output missing 'bigint' type (converted from 'int64')\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "smallint") {
|
|
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
|
|
}
|
|
}
|
|
|
|
func TestPrimaryKeyExistenceCheck(t *testing.T) {
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
table := models.InitTable("products", "public")
|
|
|
|
idCol := models.InitColumn("id", "products", "public")
|
|
idCol.Type = "integer"
|
|
idCol.IsPrimaryKey = true
|
|
table.Columns["id"] = idCol
|
|
|
|
nameCol := models.InitColumn("name", "products", "public")
|
|
nameCol.Type = "text"
|
|
table.Columns["name"] = nameCol
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify our naming convention is used
|
|
if !strings.Contains(output, "pk_public_products") {
|
|
t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output)
|
|
}
|
|
|
|
// Verify it drops auto-generated primary keys
|
|
if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") {
|
|
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
|
|
}
|
|
|
|
// Verify it checks for our specific named constraint before adding it
|
|
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
|
|
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
|
|
}
|
|
}
|
|
|
|
func TestColumnSizeSpecifiers(t *testing.T) {
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
table := models.InitTable("test_sizes", "public")
|
|
|
|
// Integer with invalid size specifier - should ignore size
|
|
integerCol := models.InitColumn("int_col", "test_sizes", "public")
|
|
integerCol.Type = "integer"
|
|
integerCol.Length = 32
|
|
table.Columns["int_col"] = integerCol
|
|
|
|
// Bigint with invalid size specifier - should ignore size
|
|
bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
|
|
bigintCol.Type = "bigint"
|
|
bigintCol.Length = 64
|
|
table.Columns["bigint_col"] = bigintCol
|
|
|
|
// Smallint with invalid size specifier - should ignore size
|
|
smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
|
|
smallintCol.Type = "smallint"
|
|
smallintCol.Length = 16
|
|
table.Columns["smallint_col"] = smallintCol
|
|
|
|
// Text with length - should convert to varchar
|
|
textCol := models.InitColumn("text_col", "test_sizes", "public")
|
|
textCol.Type = "text"
|
|
textCol.Length = 100
|
|
table.Columns["text_col"] = textCol
|
|
|
|
// Varchar with length - should keep varchar with length
|
|
varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
|
|
varcharCol.Type = "varchar"
|
|
varcharCol.Length = 50
|
|
table.Columns["varchar_col"] = varcharCol
|
|
|
|
// Decimal with precision and scale - should keep them
|
|
decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
|
|
decimalCol.Type = "decimal"
|
|
decimalCol.Precision = 19
|
|
decimalCol.Scale = 4
|
|
table.Columns["decimal_col"] = decimalCol
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify invalid size specifiers are NOT present
|
|
invalidPatterns := []string{
|
|
"integer(32)",
|
|
"bigint(64)",
|
|
"smallint(16)",
|
|
"text(100)",
|
|
}
|
|
for _, pattern := range invalidPatterns {
|
|
if strings.Contains(output, pattern) {
|
|
t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
|
|
}
|
|
}
|
|
|
|
// Verify valid patterns ARE present
|
|
validPatterns := []string{
|
|
"integer", // without size
|
|
"bigint", // without size
|
|
"smallint", // without size
|
|
"varchar(100)", // text converted to varchar with length
|
|
"varchar(50)", // varchar with length
|
|
"decimal(19,4)", // decimal with precision and scale
|
|
}
|
|
for _, pattern := range validPatterns {
|
|
if !strings.Contains(output, pattern) {
|
|
t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGenerateAddColumnStatements(t *testing.T) {
|
|
// Create a test database with tables that have new columns
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create a table with columns
|
|
table := models.InitTable("users", "public")
|
|
|
|
// Existing column
|
|
idCol := models.InitColumn("id", "users", "public")
|
|
idCol.Type = "integer"
|
|
idCol.NotNull = true
|
|
idCol.Sequence = 1
|
|
table.Columns["id"] = idCol
|
|
|
|
// New column to be added
|
|
emailCol := models.InitColumn("email", "users", "public")
|
|
emailCol.Type = "varchar"
|
|
emailCol.Length = 255
|
|
emailCol.NotNull = true
|
|
emailCol.Sequence = 2
|
|
table.Columns["email"] = emailCol
|
|
|
|
// New column with default
|
|
statusCol := models.InitColumn("status", "users", "public")
|
|
statusCol.Type = "text"
|
|
statusCol.Default = "active"
|
|
statusCol.Sequence = 3
|
|
table.Columns["status"] = statusCol
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
|
|
// Generate ADD COLUMN statements
|
|
statements, err := writer.GenerateAddColumnsForDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err)
|
|
}
|
|
|
|
// Join all statements to verify content
|
|
output := strings.Join(statements, "\n")
|
|
t.Logf("Generated ADD COLUMN statements:\n%s", output)
|
|
|
|
// Verify expected elements
|
|
expectedStrings := []string{
|
|
"ALTER TABLE public.users ADD COLUMN id integer NOT NULL",
|
|
"ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL",
|
|
"ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'",
|
|
"information_schema.columns",
|
|
"table_schema = 'public'",
|
|
"table_name = 'users'",
|
|
"column_name = 'id'",
|
|
"column_name = 'email'",
|
|
"column_name = 'status'",
|
|
}
|
|
|
|
for _, expected := range expectedStrings {
|
|
if !strings.Contains(output, expected) {
|
|
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
|
|
}
|
|
}
|
|
|
|
// Verify DO blocks are present for conditional adds
|
|
doBlockCount := strings.Count(output, "DO $$")
|
|
if doBlockCount < 3 {
|
|
t.Errorf("Expected at least 3 DO blocks (one per column), got %d", doBlockCount)
|
|
}
|
|
|
|
// Verify IF NOT EXISTS logic
|
|
ifNotExistsCount := strings.Count(output, "IF NOT EXISTS")
|
|
if ifNotExistsCount < 3 {
|
|
t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", ifNotExistsCount)
|
|
}
|
|
}
|
|
|
|
func TestWriteAddColumnStatements(t *testing.T) {
|
|
// Create a test database
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create a table with a new column to be added
|
|
table := models.InitTable("products", "public")
|
|
|
|
idCol := models.InitColumn("id", "products", "public")
|
|
idCol.Type = "integer"
|
|
table.Columns["id"] = idCol
|
|
|
|
// New column with various properties
|
|
descCol := models.InitColumn("description", "products", "public")
|
|
descCol.Type = "text"
|
|
descCol.NotNull = false
|
|
table.Columns["description"] = descCol
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write ADD COLUMN statements
|
|
err := writer.WriteAddColumnStatements(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteAddColumnStatements failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
t.Logf("Generated output:\n%s", output)
|
|
|
|
// Verify output contains expected elements
|
|
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN id integer") {
|
|
t.Errorf("Output missing ADD COLUMN for id\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "ALTER TABLE public.products ADD COLUMN description text") {
|
|
t.Errorf("Output missing ADD COLUMN for description\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "DO $$") {
|
|
t.Errorf("Output missing DO block\nFull output:\n%s", output)
|
|
}
|
|
}
|