Files
relspecgo/pkg/writers/pgsql/writer_test.go
Hein c36b5ede2b
All checks were successful
CI / Test (1.24) (push) Successful in -26m18s
CI / Test (1.25) (push) Successful in -26m11s
CI / Build (push) Successful in -26m43s
CI / Lint (push) Successful in -26m34s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m20s
feat(writer): 🎉 Enhance primary key handling and add tests
* Implement checks for existing primary keys before adding new ones.
* Drop auto-generated primary keys if they exist.
* Add tests for primary key existence and column size specifiers.
* Improve type conversion handling for PostgreSQL compatibility.
2026-01-31 18:59:32 +02:00

441 lines
12 KiB
Go

package pgsql
import (
"bytes"
"strings"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// TestWriteDatabase writes a database holding a single "users" table and
// checks that the generated SQL contains each expected statement keyword.
func TestWriteDatabase(t *testing.T) {
	db := models.InitDatabase("testdb")
	publicSchema := models.InitSchema("public")

	// The users table carries a description; the test expects a table comment.
	users := models.InitTable("users", "public")
	users.Description = "User accounts table"

	// Columns: id (with a nextval default), name and email, each described.
	id := models.InitColumn("id", "users", "public")
	id.Type = "integer"
	id.Description = "Primary key"
	id.Default = "nextval('public.identity_users_id'::regclass)"
	users.Columns["id"] = id

	name := models.InitColumn("name", "users", "public")
	name.Type = "text"
	name.Description = "User name"
	users.Columns["name"] = name

	email := models.InitColumn("email", "users", "public")
	email.Type = "text"
	email.Description = "Email address"
	users.Columns["email"] = email

	// Primary key constraint on id.
	users.Constraints["pk_users"] = &models.Constraint{
		Name:    "pk_users",
		Type:    models.PrimaryKeyConstraint,
		Columns: []string{"id"},
	}

	// Unique index on email.
	users.Indexes["uidx_users_email"] = &models.Index{
		Name:    "uidx_users_email",
		Unique:  true,
		Columns: []string{"email"},
	}

	publicSchema.Tables = append(publicSchema.Tables, users)
	db.Schemas = append(db.Schemas, publicSchema)

	// Point the writer at an in-memory buffer so the SQL can be inspected.
	var out bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &out

	if err := w.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	sql := out.String()
	// Logged to ease debugging when an assertion below fails.
	t.Logf("Generated SQL:\n%s", sql)

	// Every keyword below must appear somewhere in the generated SQL.
	wanted := [...]string{
		"CREATE TABLE",
		"PRIMARY KEY",
		"UNIQUE INDEX",
		"COMMENT ON TABLE",
		"COMMENT ON COLUMN",
	}
	for i := 0; i < len(wanted); i++ {
		if strings.Contains(sql, wanted[i]) {
			continue
		}
		t.Errorf("Output missing expected string: %s\nFull output:\n%s", wanted[i], sql)
	}
}
// TestWriteForeignKeys writes a database with a parent "users" table and a
// child "posts" table linked by a cascading foreign key, then checks that the
// generated SQL contains the foreign key and both cascade clauses.
func TestWriteForeignKeys(t *testing.T) {
	db := models.InitDatabase("testdb")
	publicSchema := models.InitSchema("public")

	// Parent table: users(id).
	users := models.InitTable("users", "public")
	usersID := models.InitColumn("id", "users", "public")
	usersID.Type = "integer"
	users.Columns["id"] = usersID

	// Child table: posts(id, user_id).
	posts := models.InitTable("posts", "public")
	postsID := models.InitColumn("id", "posts", "public")
	postsID.Type = "integer"
	posts.Columns["id"] = postsID

	postsUserID := models.InitColumn("user_id", "posts", "public")
	postsUserID.Type = "integer"
	posts.Columns["user_id"] = postsUserID

	// posts.user_id references users.id, cascading on delete and update.
	posts.Constraints["fk_posts_users"] = &models.Constraint{
		Name:              "fk_posts_users",
		Type:              models.ForeignKeyConstraint,
		Columns:           []string{"user_id"},
		ReferencedTable:   "users",
		ReferencedSchema:  "public",
		ReferencedColumns: []string{"id"},
		OnDelete:          "CASCADE",
		OnUpdate:          "CASCADE",
	}

	// Relationship record that names the same foreign key.
	posts.Relationships["fk_posts_users"] = &models.Relationship{
		Name:       "fk_posts_users",
		FromTable:  "posts",
		FromSchema: "public",
		ToTable:    "users",
		ToSchema:   "public",
		ForeignKey: "fk_posts_users",
	}

	publicSchema.Tables = append(publicSchema.Tables, users, posts)
	db.Schemas = append(db.Schemas, publicSchema)

	// Point the writer at an in-memory buffer so the SQL can be inspected.
	var out bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &out

	if err := w.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	sql := out.String()
	// Logged to ease debugging when an assertion below fails.
	t.Logf("Generated SQL:\n%s", sql)

	// absent reports whether the fragment is missing from the generated SQL.
	absent := func(fragment string) bool {
		return !strings.Contains(sql, fragment)
	}

	if absent("FOREIGN KEY") {
		t.Errorf("Output missing FOREIGN KEY statement\nFull output:\n%s", sql)
	}
	if absent("ON DELETE CASCADE") {
		t.Errorf("Output missing ON DELETE CASCADE\nFull output:\n%s", sql)
	}
	if absent("ON UPDATE CASCADE") {
		t.Errorf("Output missing ON UPDATE CASCADE\nFull output:\n%s", sql)
	}
}
// TestWriteTable writes a single "products" table through WriteTable and
// checks that the output contains a CREATE TABLE statement and the table name.
func TestWriteTable(t *testing.T) {
	// Two-column table: products(id integer, name text).
	products := models.InitTable("products", "public")

	id := models.InitColumn("id", "products", "public")
	id.Type = "integer"
	products.Columns["id"] = id

	name := models.InitColumn("name", "products", "public")
	name.Type = "text"
	products.Columns["name"] = name

	// Point the writer at an in-memory buffer so the SQL can be inspected.
	var out bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &out

	if err := w.WriteTable(products); err != nil {
		t.Fatalf("WriteTable failed: %v", err)
	}

	sql := out.String()

	if hasCreate := strings.Contains(sql, "CREATE TABLE"); !hasCreate {
		t.Error("Output missing CREATE TABLE statement")
	}
	if hasName := strings.Contains(sql, "products"); !hasName {
		t.Error("Output missing table name 'products'")
	}
}
// TestEscapeQuote checks that escapeQuote doubles every single quote and
// leaves quote-free and empty strings unchanged.
func TestEscapeQuote(t *testing.T) {
	type testCase struct {
		input    string
		expected string
	}

	cases := []testCase{
		{input: "simple text", expected: "simple text"},
		{input: "text with 'quote'", expected: "text with ''quote''"},
		{input: "multiple 'quotes' here", expected: "multiple ''quotes'' here"},
		{input: "", expected: ""},
	}

	for i := range cases {
		tc := cases[i]
		if got := escapeQuote(tc.input); got != tc.expected {
			t.Errorf("escapeQuote(%q) = %q, want %q", tc.input, got, tc.expected)
		}
	}
}
// TestIsIntegerType checks isIntegerType against integer-family type names
// (including an upper-case spelling) and against non-integer type names.
func TestIsIntegerType(t *testing.T) {
	type testCase struct {
		input    string
		expected bool
	}

	cases := []testCase{
		// Expected to be recognised as integer types.
		{input: "integer", expected: true},
		{input: "INTEGER", expected: true},
		{input: "bigint", expected: true},
		{input: "smallint", expected: true},
		{input: "serial", expected: true},
		{input: "bigserial", expected: true},
		// Expected to be rejected.
		{input: "text", expected: false},
		{input: "varchar", expected: false},
		{input: "uuid", expected: false},
	}

	for i := range cases {
		tc := cases[i]
		if got := isIntegerType(tc.input); got != tc.expected {
			t.Errorf("isIntegerType(%q) = %v, want %v", tc.input, got, tc.expected)
		}
	}
}
// TestTypeConversion writes a table whose columns are declared with Go type
// names and checks that the generated SQL contains the PostgreSQL type names
// instead: string -> text, int64 -> bigint, int16 -> smallint.
func TestTypeConversion(t *testing.T) {
	db := models.InitDatabase("testdb")
	publicSchema := models.InitSchema("public")

	// Every column below uses a Go type name, which PostgreSQL does not accept.
	tbl := models.InitTable("test_types", "public")

	name := models.InitColumn("name", "test_types", "public")
	name.Type = "string" // expected in the output as "text"
	tbl.Columns["name"] = name

	bigID := models.InitColumn("big_id", "test_types", "public")
	bigID.Type = "int64" // expected in the output as "bigint"
	tbl.Columns["big_id"] = bigID

	smallID := models.InitColumn("small_id", "test_types", "public")
	smallID.Type = "int16" // expected in the output as "smallint"
	tbl.Columns["small_id"] = smallID

	publicSchema.Tables = append(publicSchema.Tables, tbl)
	db.Schemas = append(db.Schemas, publicSchema)

	// Point the writer at an in-memory buffer so the SQL can be inspected.
	var out bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &out

	if err := w.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	sql := out.String()
	// Logged to ease debugging when an assertion below fails.
	t.Logf("Generated SQL:\n%s", sql)

	// has reports whether the fragment occurs anywhere in the generated SQL.
	has := func(fragment string) bool {
		return strings.Contains(sql, fragment)
	}

	// None of the Go type names may appear in the output.
	if has("string") {
		t.Errorf("Output contains 'string' type - should be converted to 'text'\nFull output:\n%s", sql)
	}
	if has("int64") {
		t.Errorf("Output contains 'int64' type - should be converted to 'bigint'\nFull output:\n%s", sql)
	}
	if has("int16") {
		t.Errorf("Output contains 'int16' type - should be converted to 'smallint'\nFull output:\n%s", sql)
	}

	// Each PostgreSQL replacement type must appear in the output.
	if !has("text") {
		t.Errorf("Output missing 'text' type (converted from 'string')\nFull output:\n%s", sql)
	}
	if !has("bigint") {
		t.Errorf("Output missing 'bigint' type (converted from 'int64')\nFull output:\n%s", sql)
	}
	if !has("smallint") {
		t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", sql)
	}
}
// TestPrimaryKeyExistenceCheck writes a table whose id column is flagged as
// the primary key and checks that the generated SQL names the key
// pk_public_products, drops the products_pkey constraint, and looks up the
// pk_public_products constraint by name.
func TestPrimaryKeyExistenceCheck(t *testing.T) {
	db := models.InitDatabase("testdb")
	publicSchema := models.InitSchema("public")
	products := models.InitTable("products", "public")

	// The primary key is declared on the column, not as a named constraint.
	id := models.InitColumn("id", "products", "public")
	id.Type = "integer"
	id.IsPrimaryKey = true
	products.Columns["id"] = id

	name := models.InitColumn("name", "products", "public")
	name.Type = "text"
	products.Columns["name"] = name

	publicSchema.Tables = append(publicSchema.Tables, products)
	db.Schemas = append(db.Schemas, publicSchema)

	// Point the writer at an in-memory buffer so the SQL can be inspected.
	var out bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &out

	if err := w.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	sql := out.String()
	t.Logf("Generated SQL:\n%s", sql)

	// The key must use the pk_<schema>_<table> name.
	if hasName := strings.Contains(sql, "pk_public_products"); !hasName {
		t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", sql)
	}

	// Both the products_pkey name and a DROP CONSTRAINT must be present.
	dropsAutoKey := strings.Contains(sql, "products_pkey") && strings.Contains(sql, "DROP CONSTRAINT")
	if !dropsAutoKey {
		t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", sql)
	}

	// The named constraint must be looked up before it is added.
	if checksNamed := strings.Contains(sql, "constraint_name = 'pk_public_products'"); !checksNamed {
		t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", sql)
	}
}
// TestColumnSizeSpecifiers writes a table whose columns carry length,
// precision and scale values and checks how the generated SQL renders them:
// size specifiers on integer types and on text must not appear, text with a
// length is expected as varchar(n), and varchar/decimal keep their sizes.
func TestColumnSizeSpecifiers(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	table := models.InitTable("test_sizes", "public")

	// Integer with invalid size specifier - should ignore size
	integerCol := models.InitColumn("int_col", "test_sizes", "public")
	integerCol.Type = "integer"
	integerCol.Length = 32
	table.Columns["int_col"] = integerCol

	// Bigint with invalid size specifier - should ignore size
	bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
	bigintCol.Type = "bigint"
	bigintCol.Length = 64
	table.Columns["bigint_col"] = bigintCol

	// Smallint with invalid size specifier - should ignore size
	smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
	smallintCol.Type = "smallint"
	smallintCol.Length = 16
	table.Columns["smallint_col"] = smallintCol

	// Text with length - should convert to varchar
	textCol := models.InitColumn("text_col", "test_sizes", "public")
	textCol.Type = "text"
	textCol.Length = 100
	table.Columns["text_col"] = textCol

	// Varchar with length - should keep varchar with length
	varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
	varcharCol.Type = "varchar"
	varcharCol.Length = 50
	table.Columns["varchar_col"] = varcharCol

	// Decimal with precision and scale - should keep them
	decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
	decimalCol.Type = "decimal"
	decimalCol.Precision = 19
	decimalCol.Scale = 4
	table.Columns["decimal_col"] = decimalCol

	schema.Tables = append(schema.Tables, table)
	db.Schemas = append(db.Schemas, schema)

	// Point the writer at an in-memory buffer so the SQL can be inspected.
	var buf bytes.Buffer
	options := &writers.WriterOptions{}
	writer := NewWriter(options)
	writer.writer = &buf

	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	output := buf.String()
	t.Logf("Generated SQL:\n%s", output)

	// Verify invalid size specifiers are NOT present
	invalidPatterns := []string{
		"integer(32)",
		"bigint(64)",
		"smallint(16)",
		"text(100)",
	}
	for _, pattern := range invalidPatterns {
		if strings.Contains(output, pattern) {
			t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
		}
	}

	// The column names bigint_col and smallint_col contain the type names
	// themselves, so searching the raw output for "bigint" or "smallint"
	// would succeed even if the writer never emitted those types. Blank out
	// every column identifier first so the checks below can only be
	// satisfied by text that is not part of a column name.
	typesOnly := strings.NewReplacer(
		"smallint_col", "",
		"bigint_col", "",
		"int_col", "",
		"text_col", "",
		"varchar_col", "",
		"decimal_col", "",
	).Replace(output)

	// Verify valid patterns ARE present
	validPatterns := []string{
		"integer",       // without size
		"bigint",        // without size
		"smallint",      // without size
		"varchar(100)",  // text converted to varchar with length
		"varchar(50)",   // varchar with length
		"decimal(19,4)", // decimal with precision and scale
	}
	for _, pattern := range validPatterns {
		if !strings.Contains(typesOnly, pattern) {
			t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
		}
	}
}