All checks were successful
CI / Test (1.24) (push) Successful in -26m18s
CI / Test (1.25) (push) Successful in -26m11s
CI / Build (push) Successful in -26m43s
CI / Lint (push) Successful in -26m34s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m20s
* Implement checks for existing primary keys before adding new ones.
* Drop auto-generated primary keys if they exist.
* Add tests for primary key existence and column size specifiers.
* Improve type conversion handling for PostgreSQL compatibility.
441 lines
12 KiB
Go
441 lines
12 KiB
Go
package pgsql
|
|
|
|
import (
|
|
"bytes"
|
|
"strings"
|
|
"testing"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
|
)
|
|
|
|
func TestWriteDatabase(t *testing.T) {
|
|
// Create a test database
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create a test table
|
|
table := models.InitTable("users", "public")
|
|
table.Description = "User accounts table"
|
|
|
|
// Add columns
|
|
idCol := models.InitColumn("id", "users", "public")
|
|
idCol.Type = "integer"
|
|
idCol.Description = "Primary key"
|
|
idCol.Default = "nextval('public.identity_users_id'::regclass)"
|
|
table.Columns["id"] = idCol
|
|
|
|
nameCol := models.InitColumn("name", "users", "public")
|
|
nameCol.Type = "text"
|
|
nameCol.Description = "User name"
|
|
table.Columns["name"] = nameCol
|
|
|
|
emailCol := models.InitColumn("email", "users", "public")
|
|
emailCol.Type = "text"
|
|
emailCol.Description = "Email address"
|
|
table.Columns["email"] = emailCol
|
|
|
|
// Add primary key constraint
|
|
pkConstraint := &models.Constraint{
|
|
Name: "pk_users",
|
|
Type: models.PrimaryKeyConstraint,
|
|
Columns: []string{"id"},
|
|
}
|
|
table.Constraints["pk_users"] = pkConstraint
|
|
|
|
// Add unique index
|
|
uniqueEmailIndex := &models.Index{
|
|
Name: "uidx_users_email",
|
|
Unique: true,
|
|
Columns: []string{"email"},
|
|
}
|
|
table.Indexes["uidx_users_email"] = uniqueEmailIndex
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify output contains expected elements
|
|
expectedStrings := []string{
|
|
"CREATE TABLE",
|
|
"PRIMARY KEY",
|
|
"UNIQUE INDEX",
|
|
"COMMENT ON TABLE",
|
|
"COMMENT ON COLUMN",
|
|
}
|
|
|
|
for _, expected := range expectedStrings {
|
|
if !strings.Contains(output, expected) {
|
|
t.Errorf("Output missing expected string: %s\nFull output:\n%s", expected, output)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWriteForeignKeys(t *testing.T) {
|
|
// Create a test database with two related tables
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create parent table (users)
|
|
usersTable := models.InitTable("users", "public")
|
|
idCol := models.InitColumn("id", "users", "public")
|
|
idCol.Type = "integer"
|
|
usersTable.Columns["id"] = idCol
|
|
|
|
// Create child table (posts)
|
|
postsTable := models.InitTable("posts", "public")
|
|
postIdCol := models.InitColumn("id", "posts", "public")
|
|
postIdCol.Type = "integer"
|
|
postsTable.Columns["id"] = postIdCol
|
|
|
|
userIdCol := models.InitColumn("user_id", "posts", "public")
|
|
userIdCol.Type = "integer"
|
|
postsTable.Columns["user_id"] = userIdCol
|
|
|
|
// Add foreign key constraint
|
|
fkConstraint := &models.Constraint{
|
|
Name: "fk_posts_users",
|
|
Type: models.ForeignKeyConstraint,
|
|
Columns: []string{"user_id"},
|
|
ReferencedTable: "users",
|
|
ReferencedSchema: "public",
|
|
ReferencedColumns: []string{"id"},
|
|
OnDelete: "CASCADE",
|
|
OnUpdate: "CASCADE",
|
|
}
|
|
postsTable.Constraints["fk_posts_users"] = fkConstraint
|
|
|
|
// Add relationship
|
|
relationship := &models.Relationship{
|
|
Name: "fk_posts_users",
|
|
FromTable: "posts",
|
|
FromSchema: "public",
|
|
ToTable: "users",
|
|
ToSchema: "public",
|
|
ForeignKey: "fk_posts_users",
|
|
}
|
|
postsTable.Relationships["fk_posts_users"] = relationship
|
|
|
|
schema.Tables = append(schema.Tables, usersTable, postsTable)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify foreign key is present
|
|
if !strings.Contains(output, "FOREIGN KEY") {
|
|
t.Errorf("Output missing FOREIGN KEY statement\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "ON DELETE CASCADE") {
|
|
t.Errorf("Output missing ON DELETE CASCADE\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "ON UPDATE CASCADE") {
|
|
t.Errorf("Output missing ON UPDATE CASCADE\nFull output:\n%s", output)
|
|
}
|
|
}
|
|
|
|
func TestWriteTable(t *testing.T) {
|
|
// Create a single table
|
|
table := models.InitTable("products", "public")
|
|
|
|
idCol := models.InitColumn("id", "products", "public")
|
|
idCol.Type = "integer"
|
|
table.Columns["id"] = idCol
|
|
|
|
nameCol := models.InitColumn("name", "products", "public")
|
|
nameCol.Type = "text"
|
|
table.Columns["name"] = nameCol
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the table
|
|
err := writer.WriteTable(table)
|
|
if err != nil {
|
|
t.Fatalf("WriteTable failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Verify output contains table creation
|
|
if !strings.Contains(output, "CREATE TABLE") {
|
|
t.Error("Output missing CREATE TABLE statement")
|
|
}
|
|
if !strings.Contains(output, "products") {
|
|
t.Error("Output missing table name 'products'")
|
|
}
|
|
}
|
|
|
|
func TestEscapeQuote(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"simple text", "simple text"},
|
|
{"text with 'quote'", "text with ''quote''"},
|
|
{"multiple 'quotes' here", "multiple ''quotes'' here"},
|
|
{"", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := escapeQuote(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("escapeQuote(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIsIntegerType(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected bool
|
|
}{
|
|
{"integer", true},
|
|
{"INTEGER", true},
|
|
{"bigint", true},
|
|
{"smallint", true},
|
|
{"serial", true},
|
|
{"bigserial", true},
|
|
{"text", false},
|
|
{"varchar", false},
|
|
{"uuid", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := isIntegerType(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("isIntegerType(%q) = %v, want %v", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTypeConversion(t *testing.T) {
|
|
// Test that invalid Go types are converted to valid PostgreSQL types
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Create a test table with Go types instead of SQL types
|
|
table := models.InitTable("test_types", "public")
|
|
|
|
// Add columns with Go types (invalid for PostgreSQL)
|
|
stringCol := models.InitColumn("name", "test_types", "public")
|
|
stringCol.Type = "string" // Should be converted to "text"
|
|
table.Columns["name"] = stringCol
|
|
|
|
int64Col := models.InitColumn("big_id", "test_types", "public")
|
|
int64Col.Type = "int64" // Should be converted to "bigint"
|
|
table.Columns["big_id"] = int64Col
|
|
|
|
int16Col := models.InitColumn("small_id", "test_types", "public")
|
|
int16Col.Type = "int16" // Should be converted to "smallint"
|
|
table.Columns["small_id"] = int16Col
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with output to buffer
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
// Write the database
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
|
|
// Print output for debugging
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify that Go types were converted to PostgreSQL types
|
|
if strings.Contains(output, "string") {
|
|
t.Errorf("Output contains 'string' type - should be converted to 'text'\nFull output:\n%s", output)
|
|
}
|
|
if strings.Contains(output, "int64") {
|
|
t.Errorf("Output contains 'int64' type - should be converted to 'bigint'\nFull output:\n%s", output)
|
|
}
|
|
if strings.Contains(output, "int16") {
|
|
t.Errorf("Output contains 'int16' type - should be converted to 'smallint'\nFull output:\n%s", output)
|
|
}
|
|
|
|
// Verify correct PostgreSQL types are present
|
|
if !strings.Contains(output, "text") {
|
|
t.Errorf("Output missing 'text' type (converted from 'string')\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "bigint") {
|
|
t.Errorf("Output missing 'bigint' type (converted from 'int64')\nFull output:\n%s", output)
|
|
}
|
|
if !strings.Contains(output, "smallint") {
|
|
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
|
|
}
|
|
}
|
|
|
|
func TestPrimaryKeyExistenceCheck(t *testing.T) {
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
table := models.InitTable("products", "public")
|
|
|
|
idCol := models.InitColumn("id", "products", "public")
|
|
idCol.Type = "integer"
|
|
idCol.IsPrimaryKey = true
|
|
table.Columns["id"] = idCol
|
|
|
|
nameCol := models.InitColumn("name", "products", "public")
|
|
nameCol.Type = "text"
|
|
table.Columns["name"] = nameCol
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify our naming convention is used
|
|
if !strings.Contains(output, "pk_public_products") {
|
|
t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output)
|
|
}
|
|
|
|
// Verify it drops auto-generated primary keys
|
|
if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") {
|
|
t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output)
|
|
}
|
|
|
|
// Verify it checks for our specific named constraint before adding it
|
|
if !strings.Contains(output, "constraint_name = 'pk_public_products'") {
|
|
t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output)
|
|
}
|
|
}
|
|
|
|
func TestColumnSizeSpecifiers(t *testing.T) {
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
table := models.InitTable("test_sizes", "public")
|
|
|
|
// Integer with invalid size specifier - should ignore size
|
|
integerCol := models.InitColumn("int_col", "test_sizes", "public")
|
|
integerCol.Type = "integer"
|
|
integerCol.Length = 32
|
|
table.Columns["int_col"] = integerCol
|
|
|
|
// Bigint with invalid size specifier - should ignore size
|
|
bigintCol := models.InitColumn("bigint_col", "test_sizes", "public")
|
|
bigintCol.Type = "bigint"
|
|
bigintCol.Length = 64
|
|
table.Columns["bigint_col"] = bigintCol
|
|
|
|
// Smallint with invalid size specifier - should ignore size
|
|
smallintCol := models.InitColumn("smallint_col", "test_sizes", "public")
|
|
smallintCol.Type = "smallint"
|
|
smallintCol.Length = 16
|
|
table.Columns["smallint_col"] = smallintCol
|
|
|
|
// Text with length - should convert to varchar
|
|
textCol := models.InitColumn("text_col", "test_sizes", "public")
|
|
textCol.Type = "text"
|
|
textCol.Length = 100
|
|
table.Columns["text_col"] = textCol
|
|
|
|
// Varchar with length - should keep varchar with length
|
|
varcharCol := models.InitColumn("varchar_col", "test_sizes", "public")
|
|
varcharCol.Type = "varchar"
|
|
varcharCol.Length = 50
|
|
table.Columns["varchar_col"] = varcharCol
|
|
|
|
// Decimal with precision and scale - should keep them
|
|
decimalCol := models.InitColumn("decimal_col", "test_sizes", "public")
|
|
decimalCol.Type = "decimal"
|
|
decimalCol.Precision = 19
|
|
decimalCol.Scale = 4
|
|
table.Columns["decimal_col"] = decimalCol
|
|
|
|
schema.Tables = append(schema.Tables, table)
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
var buf bytes.Buffer
|
|
options := &writers.WriterOptions{}
|
|
writer := NewWriter(options)
|
|
writer.writer = &buf
|
|
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
output := buf.String()
|
|
t.Logf("Generated SQL:\n%s", output)
|
|
|
|
// Verify invalid size specifiers are NOT present
|
|
invalidPatterns := []string{
|
|
"integer(32)",
|
|
"bigint(64)",
|
|
"smallint(16)",
|
|
"text(100)",
|
|
}
|
|
for _, pattern := range invalidPatterns {
|
|
if strings.Contains(output, pattern) {
|
|
t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output)
|
|
}
|
|
}
|
|
|
|
// Verify valid patterns ARE present
|
|
validPatterns := []string{
|
|
"integer", // without size
|
|
"bigint", // without size
|
|
"smallint", // without size
|
|
"varchar(100)", // text converted to varchar with length
|
|
"varchar(50)", // varchar with length
|
|
"decimal(19,4)", // decimal with precision and scale
|
|
}
|
|
for _, pattern := range validPatterns {
|
|
if !strings.Contains(output, pattern) {
|
|
t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output)
|
|
}
|
|
}
|
|
}
|