Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d9770b430 | |||
| f2d500f98d | |||
| 2ec9991324 | |||
| a3e45c206d | |||
| 165623bb1d | |||
| 3c20c3c5d9 |
714
pkg/commontypes/commontypes_test.go
Normal file
714
pkg/commontypes/commontypes_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package commontypes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractBaseType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
want string
|
||||
}{
|
||||
{"varchar with length", "varchar(100)", "varchar"},
|
||||
{"VARCHAR uppercase with length", "VARCHAR(255)", "varchar"},
|
||||
{"numeric with precision", "numeric(10,2)", "numeric"},
|
||||
{"NUMERIC uppercase", "NUMERIC(18,4)", "numeric"},
|
||||
{"decimal with precision", "decimal(15,3)", "decimal"},
|
||||
{"char with length", "char(50)", "char"},
|
||||
{"simple integer", "integer", "integer"},
|
||||
{"simple text", "text", "text"},
|
||||
{"bigint", "bigint", "bigint"},
|
||||
{"With spaces", " varchar(100) ", "varchar"},
|
||||
{"No parentheses", "boolean", "boolean"},
|
||||
{"Empty string", "", ""},
|
||||
{"Mixed case", "VarChar(100)", "varchar"},
|
||||
{"timestamp with time zone", "timestamp(6) with time zone", "timestamp"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ExtractBaseType(tt.sqlType)
|
||||
if got != tt.want {
|
||||
t.Errorf("ExtractBaseType(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeType(t *testing.T) {
|
||||
// NormalizeType is an alias for ExtractBaseType, test that they behave the same
|
||||
testCases := []string{
|
||||
"varchar(100)",
|
||||
"numeric(10,2)",
|
||||
"integer",
|
||||
"text",
|
||||
" VARCHAR(255) ",
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc, func(t *testing.T) {
|
||||
extracted := ExtractBaseType(tc)
|
||||
normalized := NormalizeType(tc)
|
||||
if extracted != normalized {
|
||||
t.Errorf("ExtractBaseType(%q) = %q, but NormalizeType(%q) = %q",
|
||||
tc, extracted, tc, normalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToGo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "int32"},
|
||||
{"bigint nullable", "bigint", true, "int64"},
|
||||
{"smallint nullable", "smallint", true, "int16"},
|
||||
{"serial nullable", "serial", true, "int32"},
|
||||
|
||||
// Integer types (not nullable)
|
||||
{"integer not nullable", "integer", false, "*int32"},
|
||||
{"bigint not nullable", "bigint", false, "*int64"},
|
||||
{"smallint not nullable", "smallint", false, "*int16"},
|
||||
|
||||
// String types (nullable)
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar with length nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// String types (not nullable)
|
||||
{"text not nullable", "text", false, "*string"},
|
||||
{"varchar not nullable", "varchar", false, "*string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "*bool"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "float32"},
|
||||
{"double precision nullable", "double precision", true, "float64"},
|
||||
{"real not nullable", "real", false, "*float32"},
|
||||
{"double precision not nullable", "double precision", false, "*float64"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "time.Time"},
|
||||
{"date nullable", "date", true, "time.Time"},
|
||||
{"timestamp not nullable", "timestamp", false, "*time.Time"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "[]byte"},
|
||||
{"bytea not nullable", "bytea", false, "[]byte"}, // Slices don't get pointer
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "string"},
|
||||
{"uuid not nullable", "uuid", false, "*string"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "string"},
|
||||
{"jsonb nullable", "jsonb", true, "string"},
|
||||
|
||||
// Array
|
||||
{"array nullable", "array", true, "[]string"},
|
||||
{"array not nullable", "array", false, "[]string"}, // Slices don't get pointer
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "interface{}"},
|
||||
{"unknown type not nullable", "unknowntype", false, "interface{}"}, // Interface doesn't get pointer
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToGo(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToGo(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToTypeScript(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer nullable", "integer", true, "number"},
|
||||
{"integer not nullable", "integer", false, "number | null"},
|
||||
{"bigint nullable", "bigint", true, "number"},
|
||||
{"bigint not nullable", "bigint", false, "number | null"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"text not nullable", "text", false, "string | null"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "boolean"},
|
||||
{"boolean not nullable", "boolean", false, "boolean | null"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "number"},
|
||||
{"double precision nullable", "double precision", true, "number"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "Date"},
|
||||
{"date nullable", "date", true, "Date"},
|
||||
{"timestamp not nullable", "timestamp", false, "Date | null"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "Buffer"},
|
||||
{"bytea not nullable", "bytea", false, "Buffer | null"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "any"},
|
||||
{"jsonb nullable", "jsonb", true, "any"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "string"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "any"},
|
||||
{"unknown type not nullable", "unknowntype", false, "any | null"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToTypeScript(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToTypeScript(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToPython(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
want string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer", "integer", "int"},
|
||||
{"bigint", "bigint", "int"},
|
||||
{"smallint", "smallint", "int"},
|
||||
|
||||
// String types
|
||||
{"text", "text", "str"},
|
||||
{"varchar", "varchar", "str"},
|
||||
{"varchar(100)", "varchar(100)", "str"},
|
||||
|
||||
// Boolean
|
||||
{"boolean", "boolean", "bool"},
|
||||
|
||||
// Float types
|
||||
{"real", "real", "float"},
|
||||
{"double precision", "double precision", "float"},
|
||||
{"numeric", "numeric", "Decimal"},
|
||||
{"decimal", "decimal", "Decimal"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp", "timestamp", "datetime"},
|
||||
{"date", "date", "date"},
|
||||
{"time", "time", "time"},
|
||||
|
||||
// Binary
|
||||
{"bytea", "bytea", "bytes"},
|
||||
|
||||
// JSON
|
||||
{"json", "json", "dict"},
|
||||
{"jsonb", "jsonb", "dict"},
|
||||
|
||||
// UUID
|
||||
{"uuid", "uuid", "UUID"},
|
||||
|
||||
// Array
|
||||
{"array", "array", "list"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type", "unknowntype", "Any"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToPython(tt.sqlType)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToPython(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToCSharp(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "int"},
|
||||
{"bigint nullable", "bigint", true, "long"},
|
||||
{"smallint nullable", "smallint", true, "short"},
|
||||
|
||||
// Integer types (not nullable - value types get ?)
|
||||
{"integer not nullable", "integer", false, "int?"},
|
||||
{"bigint not nullable", "bigint", false, "long?"},
|
||||
{"smallint not nullable", "smallint", false, "short?"},
|
||||
|
||||
// String types (reference types, no ? needed)
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"text not nullable", "text", false, "string"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "bool?"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "float"},
|
||||
{"double precision nullable", "double precision", true, "double"},
|
||||
{"decimal nullable", "decimal", true, "decimal"},
|
||||
{"real not nullable", "real", false, "float?"},
|
||||
{"double precision not nullable", "double precision", false, "double?"},
|
||||
{"decimal not nullable", "decimal", false, "decimal?"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "DateTime"},
|
||||
{"date nullable", "date", true, "DateTime"},
|
||||
{"timestamptz nullable", "timestamptz", true, "DateTimeOffset"},
|
||||
{"timestamp not nullable", "timestamp", false, "DateTime?"},
|
||||
{"timestamptz not nullable", "timestamptz", false, "DateTimeOffset?"},
|
||||
|
||||
// Binary (array type, no ?)
|
||||
{"bytea nullable", "bytea", true, "byte[]"},
|
||||
{"bytea not nullable", "bytea", false, "byte[]"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "Guid"},
|
||||
{"uuid not nullable", "uuid", false, "Guid?"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "string"},
|
||||
|
||||
// Unknown types (object is reference type)
|
||||
{"unknown type nullable", "unknowntype", true, "object"},
|
||||
{"unknown type not nullable", "unknowntype", false, "object"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToCSharp(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToCSharp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsTimeImport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
goType string
|
||||
want bool
|
||||
}{
|
||||
{"time.Time type", "time.Time", true},
|
||||
{"pointer to time.Time", "*time.Time", true},
|
||||
{"int32 type", "int32", false},
|
||||
{"string type", "string", false},
|
||||
{"bool type", "bool", false},
|
||||
{"[]byte type", "[]byte", false},
|
||||
{"interface{}", "interface{}", false},
|
||||
{"empty string", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NeedsTimeImport(tt.goType)
|
||||
if got != tt.want {
|
||||
t.Errorf("NeedsTimeImport(%q) = %v, want %v", tt.goType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoTypeMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int32",
|
||||
"bigint": "int64",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "float64",
|
||||
"bytea": "[]byte",
|
||||
"timestamp": "time.Time",
|
||||
"uuid": "string",
|
||||
"json": "string",
|
||||
}
|
||||
|
||||
for sqlType, expectedGoType := range expectedMappings {
|
||||
if goType, ok := GoTypeMap[sqlType]; !ok {
|
||||
t.Errorf("GoTypeMap missing entry for %q", sqlType)
|
||||
} else if goType != expectedGoType {
|
||||
t.Errorf("GoTypeMap[%q] = %q, want %q", sqlType, goType, expectedGoType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(GoTypeMap) == 0 {
|
||||
t.Error("GoTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeScriptTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "number",
|
||||
"bigint": "number",
|
||||
"text": "string",
|
||||
"boolean": "boolean",
|
||||
"double precision": "number",
|
||||
"bytea": "Buffer",
|
||||
"timestamp": "Date",
|
||||
"uuid": "string",
|
||||
"json": "any",
|
||||
}
|
||||
|
||||
for sqlType, expectedTSType := range expectedMappings {
|
||||
if tsType, ok := TypeScriptTypeMap[sqlType]; !ok {
|
||||
t.Errorf("TypeScriptTypeMap missing entry for %q", sqlType)
|
||||
} else if tsType != expectedTSType {
|
||||
t.Errorf("TypeScriptTypeMap[%q] = %q, want %q", sqlType, tsType, expectedTSType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(TypeScriptTypeMap) == 0 {
|
||||
t.Error("TypeScriptTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPythonTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int",
|
||||
"text": "str",
|
||||
"boolean": "bool",
|
||||
"real": "float",
|
||||
"numeric": "Decimal",
|
||||
"bytea": "bytes",
|
||||
"date": "date",
|
||||
"uuid": "UUID",
|
||||
"json": "dict",
|
||||
}
|
||||
|
||||
for sqlType, expectedPyType := range expectedMappings {
|
||||
if pyType, ok := PythonTypeMap[sqlType]; !ok {
|
||||
t.Errorf("PythonTypeMap missing entry for %q", sqlType)
|
||||
} else if pyType != expectedPyType {
|
||||
t.Errorf("PythonTypeMap[%q] = %q, want %q", sqlType, pyType, expectedPyType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(PythonTypeMap) == 0 {
|
||||
t.Error("PythonTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSharpTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "long",
|
||||
"smallint": "short",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "double",
|
||||
"decimal": "decimal",
|
||||
"bytea": "byte[]",
|
||||
"timestamp": "DateTime",
|
||||
"uuid": "Guid",
|
||||
"json": "string",
|
||||
}
|
||||
|
||||
for sqlType, expectedCSType := range expectedMappings {
|
||||
if csType, ok := CSharpTypeMap[sqlType]; !ok {
|
||||
t.Errorf("CSharpTypeMap missing entry for %q", sqlType)
|
||||
} else if csType != expectedCSType {
|
||||
t.Errorf("CSharpTypeMap[%q] = %q, want %q", sqlType, csType, expectedCSType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(CSharpTypeMap) == 0 {
|
||||
t.Error("CSharpTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToJava(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer nullable", "integer", true, "Integer"},
|
||||
{"integer not nullable", "integer", false, "Integer"},
|
||||
{"bigint nullable", "bigint", true, "Long"},
|
||||
{"smallint nullable", "smallint", true, "Short"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "String"},
|
||||
{"varchar nullable", "varchar", true, "String"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "Boolean"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "Float"},
|
||||
{"double precision nullable", "double precision", true, "Double"},
|
||||
{"numeric nullable", "numeric", true, "BigDecimal"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "Timestamp"},
|
||||
{"date nullable", "date", true, "Date"},
|
||||
{"time nullable", "time", true, "Time"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "byte[]"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "UUID"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "String"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "Object"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToJava(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToJava(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToPhp(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "int"},
|
||||
{"bigint nullable", "bigint", true, "int"},
|
||||
{"smallint nullable", "smallint", true, "int"},
|
||||
|
||||
// Integer types (not nullable)
|
||||
{"integer not nullable", "integer", false, "?int"},
|
||||
{"bigint not nullable", "bigint", false, "?int"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "string"},
|
||||
{"text not nullable", "text", false, "?string"},
|
||||
{"varchar nullable", "varchar", true, "string"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "?bool"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "float"},
|
||||
{"double precision nullable", "double precision", true, "float"},
|
||||
{"real not nullable", "real", false, "?float"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "\\DateTime"},
|
||||
{"date nullable", "date", true, "\\DateTime"},
|
||||
{"timestamp not nullable", "timestamp", false, "?\\DateTime"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "string"},
|
||||
{"bytea not nullable", "bytea", false, "?string"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "array"},
|
||||
{"json not nullable", "json", false, "?array"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "string"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "mixed"},
|
||||
{"unknown type not nullable", "unknowntype", false, "mixed"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToPhp(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToPhp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLToRust(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlType string
|
||||
nullable bool
|
||||
want string
|
||||
}{
|
||||
// Integer types (nullable)
|
||||
{"integer nullable", "integer", true, "i32"},
|
||||
{"bigint nullable", "bigint", true, "i64"},
|
||||
{"smallint nullable", "smallint", true, "i16"},
|
||||
|
||||
// Integer types (not nullable)
|
||||
{"integer not nullable", "integer", false, "Option<i32>"},
|
||||
{"bigint not nullable", "bigint", false, "Option<i64>"},
|
||||
{"smallint not nullable", "smallint", false, "Option<i16>"},
|
||||
|
||||
// String types
|
||||
{"text nullable", "text", true, "String"},
|
||||
{"text not nullable", "text", false, "Option<String>"},
|
||||
{"varchar nullable", "varchar", true, "String"},
|
||||
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||
|
||||
// Boolean
|
||||
{"boolean nullable", "boolean", true, "bool"},
|
||||
{"boolean not nullable", "boolean", false, "Option<bool>"},
|
||||
|
||||
// Float types
|
||||
{"real nullable", "real", true, "f32"},
|
||||
{"double precision nullable", "double precision", true, "f64"},
|
||||
{"real not nullable", "real", false, "Option<f32>"},
|
||||
{"double precision not nullable", "double precision", false, "Option<f64>"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp nullable", "timestamp", true, "NaiveDateTime"},
|
||||
{"timestamptz nullable", "timestamptz", true, "DateTime<Utc>"},
|
||||
{"date nullable", "date", true, "NaiveDate"},
|
||||
{"time nullable", "time", true, "NaiveTime"},
|
||||
{"timestamp not nullable", "timestamp", false, "Option<NaiveDateTime>"},
|
||||
|
||||
// Binary
|
||||
{"bytea nullable", "bytea", true, "Vec<u8>"},
|
||||
{"bytea not nullable", "bytea", false, "Option<Vec<u8>>"},
|
||||
|
||||
// JSON
|
||||
{"json nullable", "json", true, "serde_json::Value"},
|
||||
{"json not nullable", "json", false, "Option<serde_json::Value>"},
|
||||
|
||||
// UUID
|
||||
{"uuid nullable", "uuid", true, "String"},
|
||||
|
||||
// Unknown types
|
||||
{"unknown type nullable", "unknowntype", true, "String"},
|
||||
{"unknown type not nullable", "unknowntype", false, "Option<String>"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SQLToRust(tt.sqlType, tt.nullable)
|
||||
if got != tt.want {
|
||||
t.Errorf("SQLToRust(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJavaTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "Integer",
|
||||
"bigint": "Long",
|
||||
"smallint": "Short",
|
||||
"text": "String",
|
||||
"boolean": "Boolean",
|
||||
"double precision": "Double",
|
||||
"numeric": "BigDecimal",
|
||||
"bytea": "byte[]",
|
||||
"timestamp": "Timestamp",
|
||||
"uuid": "UUID",
|
||||
"date": "Date",
|
||||
}
|
||||
|
||||
for sqlType, expectedJavaType := range expectedMappings {
|
||||
if javaType, ok := JavaTypeMap[sqlType]; !ok {
|
||||
t.Errorf("JavaTypeMap missing entry for %q", sqlType)
|
||||
} else if javaType != expectedJavaType {
|
||||
t.Errorf("JavaTypeMap[%q] = %q, want %q", sqlType, javaType, expectedJavaType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(JavaTypeMap) == 0 {
|
||||
t.Error("JavaTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPHPTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"double precision": "float",
|
||||
"bytea": "string",
|
||||
"timestamp": "\\DateTime",
|
||||
"uuid": "string",
|
||||
"json": "array",
|
||||
}
|
||||
|
||||
for sqlType, expectedPHPType := range expectedMappings {
|
||||
if phpType, ok := PHPTypeMap[sqlType]; !ok {
|
||||
t.Errorf("PHPTypeMap missing entry for %q", sqlType)
|
||||
} else if phpType != expectedPHPType {
|
||||
t.Errorf("PHPTypeMap[%q] = %q, want %q", sqlType, phpType, expectedPHPType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(PHPTypeMap) == 0 {
|
||||
t.Error("PHPTypeMap is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRustTypeMap(t *testing.T) {
|
||||
expectedMappings := map[string]string{
|
||||
"integer": "i32",
|
||||
"bigint": "i64",
|
||||
"smallint": "i16",
|
||||
"text": "String",
|
||||
"boolean": "bool",
|
||||
"double precision": "f64",
|
||||
"real": "f32",
|
||||
"bytea": "Vec<u8>",
|
||||
"timestamp": "NaiveDateTime",
|
||||
"timestamptz": "DateTime<Utc>",
|
||||
"date": "NaiveDate",
|
||||
"json": "serde_json::Value",
|
||||
}
|
||||
|
||||
for sqlType, expectedRustType := range expectedMappings {
|
||||
if rustType, ok := RustTypeMap[sqlType]; !ok {
|
||||
t.Errorf("RustTypeMap missing entry for %q", sqlType)
|
||||
} else if rustType != expectedRustType {
|
||||
t.Errorf("RustTypeMap[%q] = %q, want %q", sqlType, rustType, expectedRustType)
|
||||
}
|
||||
}
|
||||
|
||||
if len(RustTypeMap) == 0 {
|
||||
t.Error("RustTypeMap is empty")
|
||||
}
|
||||
}
|
||||
558
pkg/diff/diff_test.go
Normal file
558
pkg/diff/diff_test.go
Normal file
@@ -0,0 +1,558 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestCompareDatabases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source *models.Database
|
||||
target *models.Database
|
||||
want func(*DiffResult) bool
|
||||
}{
|
||||
{
|
||||
name: "identical databases",
|
||||
source: &models.Database{
|
||||
Name: "source",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
target: &models.Database{
|
||||
Name: "target",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
want: func(r *DiffResult) bool {
|
||||
return r.Source == "source" && r.Target == "target" &&
|
||||
len(r.Schemas.Missing) == 0 && len(r.Schemas.Extra) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different schemas",
|
||||
source: &models.Database{
|
||||
Name: "source",
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
target: &models.Database{
|
||||
Name: "target",
|
||||
Schemas: []*models.Schema{},
|
||||
},
|
||||
want: func(r *DiffResult) bool {
|
||||
return len(r.Schemas.Missing) == 1 && r.Schemas.Missing[0].Name == "public"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := CompareDatabases(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("CompareDatabases() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareColumns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Column
|
||||
target map[string]*models.Column
|
||||
want func(*ColumnDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical columns",
|
||||
source: map[string]*models.Column{},
|
||||
target: map[string]*models.Column{},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing column",
|
||||
source: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
target: map[string]*models.Column{},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra column",
|
||||
source: map[string]*models.Column{},
|
||||
target: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column type",
|
||||
source: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint"},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "id" &&
|
||||
d.Modified[0].Changes["type"] != nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column nullable",
|
||||
source: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "text", NotNull: true},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "text", NotNull: false},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Changes["not_null"] != nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified column length",
|
||||
source: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "varchar", Length: 100},
|
||||
},
|
||||
target: map[string]*models.Column{
|
||||
"name": {Name: "name", Type: "varchar", Length: 255},
|
||||
},
|
||||
want: func(d *ColumnDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Changes["length"] != nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareColumns(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareColumns() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareColumnDetails(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source *models.Column
|
||||
target *models.Column
|
||||
want int // number of changes
|
||||
}{
|
||||
{
|
||||
name: "identical columns",
|
||||
source: &models.Column{Name: "id", Type: "integer"},
|
||||
target: &models.Column{Name: "id", Type: "integer"},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "type change",
|
||||
source: &models.Column{Name: "id", Type: "integer"},
|
||||
target: &models.Column{Name: "id", Type: "bigint"},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "length change",
|
||||
source: &models.Column{Name: "name", Type: "varchar", Length: 100},
|
||||
target: &models.Column{Name: "name", Type: "varchar", Length: 255},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "precision change",
|
||||
source: &models.Column{Name: "price", Type: "numeric", Precision: 10},
|
||||
target: &models.Column{Name: "price", Type: "numeric", Precision: 12},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "scale change",
|
||||
source: &models.Column{Name: "price", Type: "numeric", Scale: 2},
|
||||
target: &models.Column{Name: "price", Type: "numeric", Scale: 4},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "not null change",
|
||||
source: &models.Column{Name: "name", Type: "text", NotNull: true},
|
||||
target: &models.Column{Name: "name", Type: "text", NotNull: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "auto increment change",
|
||||
source: &models.Column{Name: "id", Type: "integer", AutoIncrement: true},
|
||||
target: &models.Column{Name: "id", Type: "integer", AutoIncrement: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "primary key change",
|
||||
source: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: true},
|
||||
target: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: false},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple changes",
|
||||
source: &models.Column{Name: "id", Type: "integer", NotNull: true, AutoIncrement: true},
|
||||
target: &models.Column{Name: "id", Type: "bigint", NotNull: false, AutoIncrement: false},
|
||||
want: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareColumnDetails(tt.source, tt.target)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("compareColumnDetails() = %d changes, want %d", len(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareIndexes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Index
|
||||
target map[string]*models.Index
|
||||
want func(*IndexDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical indexes",
|
||||
source: map[string]*models.Index{},
|
||||
target: map[string]*models.Index{},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing index",
|
||||
source: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||
},
|
||||
target: map[string]*models.Index{},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra index",
|
||||
source: map[string]*models.Index{},
|
||||
target: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||
},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified index uniqueness",
|
||||
source: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: false},
|
||||
},
|
||||
target: map[string]*models.Index{
|
||||
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: true},
|
||||
},
|
||||
want: func(d *IndexDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "idx_name"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareIndexes(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareIndexes() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareConstraints(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Constraint
|
||||
target map[string]*models.Constraint
|
||||
want func(*ConstraintDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical constraints",
|
||||
source: map[string]*models.Constraint{},
|
||||
target: map[string]*models.Constraint{},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing constraint",
|
||||
source: map[string]*models.Constraint{
|
||||
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||
},
|
||||
target: map[string]*models.Constraint{},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "pk_id"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra constraint",
|
||||
source: map[string]*models.Constraint{},
|
||||
target: map[string]*models.Constraint{
|
||||
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||
},
|
||||
want: func(d *ConstraintDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "pk_id"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareConstraints(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareConstraints() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareRelationships(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source map[string]*models.Relationship
|
||||
target map[string]*models.Relationship
|
||||
want func(*RelationshipDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical relationships",
|
||||
source: map[string]*models.Relationship{},
|
||||
target: map[string]*models.Relationship{},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing relationship",
|
||||
source: map[string]*models.Relationship{
|
||||
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||
},
|
||||
target: map[string]*models.Relationship{},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "fk_user"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra relationship",
|
||||
source: map[string]*models.Relationship{},
|
||||
target: map[string]*models.Relationship{
|
||||
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||
},
|
||||
want: func(d *RelationshipDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "fk_user"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareRelationships(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareRelationships() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source []*models.Table
|
||||
target []*models.Table
|
||||
want func(*TableDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical tables",
|
||||
source: []*models.Table{},
|
||||
target: []*models.Table{},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing table",
|
||||
source: []*models.Table{
|
||||
{Name: "users", Schema: "public"},
|
||||
},
|
||||
target: []*models.Table{},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "users"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra table",
|
||||
source: []*models.Table{},
|
||||
target: []*models.Table{
|
||||
{Name: "users", Schema: "public"},
|
||||
},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "users"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "modified table",
|
||||
source: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "integer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
target: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: func(d *TableDiff) bool {
|
||||
return len(d.Modified) == 1 && d.Modified[0].Name == "users"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareTables(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareTables() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareSchemas(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
source []*models.Schema
|
||||
target []*models.Schema
|
||||
want func(*SchemaDiff) bool
|
||||
}{
|
||||
{
|
||||
name: "identical schemas",
|
||||
source: []*models.Schema{},
|
||||
target: []*models.Schema{},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing schema",
|
||||
source: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
target: []*models.Schema{},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Missing) == 1 && d.Missing[0].Name == "public"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra schema",
|
||||
source: []*models.Schema{},
|
||||
target: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
want: func(d *SchemaDiff) bool {
|
||||
return len(d.Extra) == 1 && d.Extra[0].Name == "public"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := compareSchemas(tt.source, tt.target)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("compareSchemas() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v interface{}
|
||||
want bool
|
||||
}{
|
||||
{"empty ColumnDiff", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, true},
|
||||
{"ColumnDiff with missing", &ColumnDiff{Missing: []*models.Column{{Name: "id"}}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, false},
|
||||
{"ColumnDiff with extra", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{{Name: "id"}}, Modified: []*ColumnChange{}}, false},
|
||||
{"empty IndexDiff", &IndexDiff{Missing: []*models.Index{}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, true},
|
||||
{"IndexDiff with missing", &IndexDiff{Missing: []*models.Index{{Name: "idx"}}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, false},
|
||||
{"empty TableDiff", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{}, Modified: []*TableChange{}}, true},
|
||||
{"TableDiff with extra", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{{Name: "users"}}, Modified: []*TableChange{}}, false},
|
||||
{"empty ConstraintDiff", &ConstraintDiff{Missing: []*models.Constraint{}, Extra: []*models.Constraint{}, Modified: []*ConstraintChange{}}, true},
|
||||
{"empty RelationshipDiff", &RelationshipDiff{Missing: []*models.Relationship{}, Extra: []*models.Relationship{}, Modified: []*RelationshipChange{}}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isEmpty(tt.v)
|
||||
if got != tt.want {
|
||||
t.Errorf("isEmpty() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeSummary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
want func(*Summary) bool
|
||||
}{
|
||||
{
|
||||
name: "empty diff",
|
||||
result: &DiffResult{
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
want: func(s *Summary) bool {
|
||||
return s.Schemas.Missing == 0 && s.Schemas.Extra == 0 && s.Schemas.Modified == 0
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "schemas with differences",
|
||||
result: &DiffResult{
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{{Name: "schema2"}, {Name: "schema3"}},
|
||||
Modified: []*SchemaChange{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: func(s *Summary) bool {
|
||||
return s.Schemas.Missing == 1 && s.Schemas.Extra == 2 && s.Schemas.Modified == 1
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ComputeSummary(tt.result)
|
||||
if !tt.want(got) {
|
||||
t.Errorf("ComputeSummary() result doesn't match expectations")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
440
pkg/diff/formatters_test.go
Normal file
440
pkg/diff/formatters_test.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestFormatDiff(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source_db",
|
||||
Target: "target_db",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
format OutputFormat
|
||||
wantErr bool
|
||||
}{
|
||||
{"summary format", FormatSummary, false},
|
||||
{"json format", FormatJSON, false},
|
||||
{"html format", FormatHTML, false},
|
||||
{"invalid format", OutputFormat("invalid"), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := FormatDiff(result, tt.format, &buf)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FormatDiff() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && buf.Len() == 0 {
|
||||
t.Error("FormatDiff() produced empty output")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
wantStr []string // strings that should appear in output
|
||||
}{
|
||||
{
|
||||
name: "no differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"source", "target", "No differences found"},
|
||||
},
|
||||
{
|
||||
name: "with schema differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{{Name: "schema2"}},
|
||||
Modified: []*SchemaChange{
|
||||
{Name: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"Schemas:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||
},
|
||||
{
|
||||
name: "with table differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Missing: []*models.Table{{Name: "users"}},
|
||||
Extra: []*models.Table{{Name: "posts"}},
|
||||
Modified: []*TableChange{
|
||||
{Name: "comments", Schema: "public"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{"Tables:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(tt.result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
for _, want := range tt.wantStr {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatJSON(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "schema1"}},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatJSON(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatJSON() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if output is valid JSON
|
||||
var decoded DiffResult
|
||||
if err := json.Unmarshal(buf.Bytes(), &decoded); err != nil {
|
||||
t.Errorf("formatJSON() produced invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check basic structure
|
||||
if decoded.Source != "source" {
|
||||
t.Errorf("formatJSON() source = %v, want %v", decoded.Source, "source")
|
||||
}
|
||||
if decoded.Target != "target" {
|
||||
t.Errorf("formatJSON() target = %v, want %v", decoded.Target, "target")
|
||||
}
|
||||
if len(decoded.Schemas.Missing) != 1 {
|
||||
t.Errorf("formatJSON() missing schemas = %v, want 1", len(decoded.Schemas.Missing))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatHTML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *DiffResult
|
||||
wantStr []string // HTML elements/content that should appear
|
||||
}{
|
||||
{
|
||||
name: "basic HTML structure",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{},
|
||||
Extra: []*models.Schema{},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"<!DOCTYPE html>",
|
||||
"<title>Database Diff Report</title>",
|
||||
"source",
|
||||
"target",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with schema differences",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "missing_schema"}},
|
||||
Extra: []*models.Schema{{Name: "extra_schema"}},
|
||||
Modified: []*SchemaChange{},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"<!DOCTYPE html>",
|
||||
"missing_schema",
|
||||
"extra_schema",
|
||||
"MISSING",
|
||||
"EXTRA",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with table modifications",
|
||||
result: &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: &ColumnDiff{
|
||||
Missing: []*models.Column{{Name: "email", Type: "text"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStr: []string{
|
||||
"public",
|
||||
"users",
|
||||
"email",
|
||||
"text",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := formatHTML(tt.result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatHTML() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
for _, want := range tt.wantStr {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("formatHTML() output doesn't contain %q", want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryWithColumns(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: &ColumnDiff{
|
||||
Missing: []*models.Column{{Name: "email"}},
|
||||
Extra: []*models.Column{{Name: "phone"}, {Name: "address"}},
|
||||
Modified: []*ColumnChange{
|
||||
{Name: "name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
wantStrings := []string{
|
||||
"Columns:",
|
||||
"Missing: 1",
|
||||
"Extra: 2",
|
||||
"Modified: 1",
|
||||
}
|
||||
|
||||
for _, want := range wantStrings {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryWithIndexes(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Indexes: &IndexDiff{
|
||||
Missing: []*models.Index{{Name: "idx_email"}},
|
||||
Extra: []*models.Index{{Name: "idx_phone"}},
|
||||
Modified: []*IndexChange{{Name: "idx_name"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Indexes:") {
|
||||
t.Error("formatSummary() output doesn't contain Indexes section")
|
||||
}
|
||||
if !strings.Contains(output, "Missing: 1") {
|
||||
t.Error("formatSummary() output doesn't contain correct missing count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryWithConstraints(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Modified: []*SchemaChange{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: &TableDiff{
|
||||
Modified: []*TableChange{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Constraints: &ConstraintDiff{
|
||||
Missing: []*models.Constraint{{Name: "pk_users", Type: "PRIMARY KEY"}},
|
||||
Extra: []*models.Constraint{{Name: "fk_users_roles", Type: "FOREIGN KEY"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatSummary(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatSummary() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Constraints:") {
|
||||
t.Error("formatSummary() output doesn't contain Constraints section")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatJSONIndentation(t *testing.T) {
|
||||
result := &DiffResult{
|
||||
Source: "source",
|
||||
Target: "target",
|
||||
Schemas: &SchemaDiff{
|
||||
Missing: []*models.Schema{{Name: "test"}},
|
||||
},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := formatJSON(result, &buf)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("formatJSON() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that JSON is indented (has newlines and spaces)
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "\n") {
|
||||
t.Error("formatJSON() should produce indented JSON with newlines")
|
||||
}
|
||||
if !strings.Contains(output, " ") {
|
||||
t.Error("formatJSON() should produce indented JSON with spaces")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputFormatConstants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
format OutputFormat
|
||||
want string
|
||||
}{
|
||||
{"summary constant", FormatSummary, "summary"},
|
||||
{"json constant", FormatJSON, "json"},
|
||||
{"html constant", FormatHTML, "html"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if string(tt.format) != tt.want {
|
||||
t.Errorf("OutputFormat %v = %v, want %v", tt.name, tt.format, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
238
pkg/inspector/inspector_test.go
Normal file
238
pkg/inspector/inspector_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewInspector(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
|
||||
if inspector == nil {
|
||||
t.Fatal("NewInspector() returned nil")
|
||||
}
|
||||
|
||||
if inspector.db != db {
|
||||
t.Error("NewInspector() database not set correctly")
|
||||
}
|
||||
|
||||
if inspector.config != config {
|
||||
t.Error("NewInspector() config not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspect(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() returned error: %v", err)
|
||||
}
|
||||
|
||||
if report == nil {
|
||||
t.Fatal("Inspect() returned nil report")
|
||||
}
|
||||
|
||||
if report.Database != db.Name {
|
||||
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
|
||||
}
|
||||
|
||||
if report.Summary.TotalRules != len(config.Rules) {
|
||||
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
|
||||
}
|
||||
|
||||
if len(report.Violations) == 0 {
|
||||
t.Error("Inspect() returned no violations, expected some results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectWithDisabledRules(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Disable all rules
|
||||
for name := range config.Rules {
|
||||
rule := config.Rules[name]
|
||||
rule.Enabled = "off"
|
||||
config.Rules[name] = rule
|
||||
}
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
|
||||
}
|
||||
|
||||
if report.Summary.RulesChecked != 0 {
|
||||
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
|
||||
}
|
||||
|
||||
if report.Summary.RulesSkipped != len(config.Rules) {
|
||||
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectWithEnforcedRules(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Enable only one rule and enforce it
|
||||
for name := range config.Rules {
|
||||
rule := config.Rules[name]
|
||||
rule.Enabled = "off"
|
||||
config.Rules[name] = rule
|
||||
}
|
||||
|
||||
primaryKeyRule := config.Rules["primary_key_naming"]
|
||||
primaryKeyRule.Enabled = "enforce"
|
||||
primaryKeyRule.Pattern = "^id$"
|
||||
config.Rules["primary_key_naming"] = primaryKeyRule
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() returned error: %v", err)
|
||||
}
|
||||
|
||||
if report.Summary.RulesChecked != 1 {
|
||||
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
|
||||
}
|
||||
|
||||
// All results should be at error level for enforced rules
|
||||
for _, violation := range report.Violations {
|
||||
if violation.Level != "error" {
|
||||
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSummary(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
inspector := NewInspector(db, config)
|
||||
|
||||
results := []ValidationResult{
|
||||
{RuleName: "rule1", Passed: true, Level: "error"},
|
||||
{RuleName: "rule2", Passed: false, Level: "error"},
|
||||
{RuleName: "rule3", Passed: false, Level: "warning"},
|
||||
{RuleName: "rule4", Passed: true, Level: "warning"},
|
||||
}
|
||||
|
||||
summary := inspector.generateSummary(results)
|
||||
|
||||
if summary.PassedCount != 2 {
|
||||
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
|
||||
}
|
||||
|
||||
if summary.ErrorCount != 1 {
|
||||
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
|
||||
}
|
||||
|
||||
if summary.WarningCount != 1 {
|
||||
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
report *InspectorReport
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "with errors",
|
||||
report: &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
ErrorCount: 5,
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "without errors",
|
||||
report: &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
ErrorCount: 0,
|
||||
WarningCount: 3,
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.report.HasErrors(); got != tt.want {
|
||||
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
functionName string
|
||||
wantExists bool
|
||||
}{
|
||||
{"primary_key_naming", "primary_key_naming", true},
|
||||
{"primary_key_datatype", "primary_key_datatype", true},
|
||||
{"foreign_key_column_naming", "foreign_key_column_naming", true},
|
||||
{"table_regexpr", "table_regexpr", true},
|
||||
{"column_regexpr", "column_regexpr", true},
|
||||
{"reserved_words", "reserved_words", true},
|
||||
{"have_primary_key", "have_primary_key", true},
|
||||
{"orphaned_foreign_key", "orphaned_foreign_key", true},
|
||||
{"circular_dependency", "circular_dependency", true},
|
||||
{"unknown_function", "unknown_function", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, exists := getValidator(tt.functionName)
|
||||
if exists != tt.wantExists {
|
||||
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateResult(t *testing.T) {
|
||||
result := createResult(
|
||||
"test_rule",
|
||||
true,
|
||||
"Test message",
|
||||
"schema.table.column",
|
||||
map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
},
|
||||
)
|
||||
|
||||
if result.RuleName != "test_rule" {
|
||||
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
|
||||
}
|
||||
|
||||
if !result.Passed {
|
||||
t.Error("createResult() Passed = false, want true")
|
||||
}
|
||||
|
||||
if result.Message != "Test message" {
|
||||
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
|
||||
}
|
||||
|
||||
if result.Location != "schema.table.column" {
|
||||
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
|
||||
}
|
||||
|
||||
if len(result.Context) != 2 {
|
||||
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
|
||||
}
|
||||
}
|
||||
366
pkg/inspector/report_test.go
Normal file
366
pkg/inspector/report_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func createTestReport() *InspectorReport {
|
||||
return &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
TotalRules: 10,
|
||||
RulesChecked: 8,
|
||||
RulesSkipped: 2,
|
||||
ErrorCount: 3,
|
||||
WarningCount: 5,
|
||||
PassedCount: 12,
|
||||
},
|
||||
Violations: []ValidationResult{
|
||||
{
|
||||
RuleName: "primary_key_naming",
|
||||
Level: "error",
|
||||
Message: "Primary key should start with 'id_'",
|
||||
Location: "public.users.user_id",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "users",
|
||||
"column": "user_id",
|
||||
"pattern": "^id_",
|
||||
},
|
||||
},
|
||||
{
|
||||
RuleName: "table_name_length",
|
||||
Level: "warning",
|
||||
Message: "Table name too long",
|
||||
Location: "public.very_long_table_name_that_exceeds_limits",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "very_long_table_name_that_exceeds_limits",
|
||||
"length": 44,
|
||||
"max_length": 32,
|
||||
},
|
||||
},
|
||||
},
|
||||
GeneratedAt: time.Now(),
|
||||
Database: "testdb",
|
||||
SourceFormat: "postgresql",
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMarkdownFormatter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
if formatter == nil {
|
||||
t.Fatal("NewMarkdownFormatter() returned nil")
|
||||
}
|
||||
|
||||
// Buffer is not a terminal, so colors should be disabled
|
||||
if formatter.UseColors {
|
||||
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewJSONFormatter(t *testing.T) {
|
||||
formatter := NewJSONFormatter()
|
||||
|
||||
if formatter == nil {
|
||||
t.Fatal("NewJSONFormatter() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_Format(t *testing.T) {
|
||||
report := createTestReport()
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Check that output contains expected sections
|
||||
if !strings.Contains(output, "# RelSpec Inspector Report") {
|
||||
t.Error("Markdown output missing header")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Database:") {
|
||||
t.Error("Markdown output missing database field")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "testdb") {
|
||||
t.Error("Markdown output missing database name")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Summary") {
|
||||
t.Error("Markdown output missing summary section")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Rules Checked: 8") {
|
||||
t.Error("Markdown output missing rules checked count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Errors: 3") {
|
||||
t.Error("Markdown output missing error count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Warnings: 5") {
|
||||
t.Error("Markdown output missing warning count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Violations") {
|
||||
t.Error("Markdown output missing violations section")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "primary_key_naming") {
|
||||
t.Error("Markdown output missing rule name")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "public.users.user_id") {
|
||||
t.Error("Markdown output missing location")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
|
||||
report := &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
TotalRules: 10,
|
||||
RulesChecked: 10,
|
||||
RulesSkipped: 0,
|
||||
ErrorCount: 0,
|
||||
WarningCount: 0,
|
||||
PassedCount: 50,
|
||||
},
|
||||
Violations: []ValidationResult{},
|
||||
GeneratedAt: time.Now(),
|
||||
Database: "testdb",
|
||||
SourceFormat: "postgresql",
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "No violations found") {
|
||||
t.Error("Markdown output should indicate no violations")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFormatter_Format(t *testing.T) {
|
||||
report := createTestReport()
|
||||
formatter := NewJSONFormatter()
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's valid JSON
|
||||
var decoded InspectorReport
|
||||
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
|
||||
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check key fields
|
||||
if decoded.Database != "testdb" {
|
||||
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
|
||||
}
|
||||
|
||||
if decoded.Summary.ErrorCount != 3 {
|
||||
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
|
||||
}
|
||||
|
||||
if len(decoded.Violations) != 2 {
|
||||
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
header := formatter.formatHeader("Test Header")
|
||||
|
||||
if !strings.Contains(header, "# Test Header") {
|
||||
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatBold(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
useColors bool
|
||||
text string
|
||||
wantContains string
|
||||
}{
|
||||
{
|
||||
name: "without colors",
|
||||
useColors: false,
|
||||
text: "Bold Text",
|
||||
wantContains: "**Bold Text**",
|
||||
},
|
||||
{
|
||||
name: "with colors",
|
||||
useColors: true,
|
||||
text: "Bold Text",
|
||||
wantContains: "Bold Text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||
result := formatter.formatBold(tt.text)
|
||||
|
||||
if !strings.Contains(result, tt.wantContains) {
|
||||
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_Colorize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
useColors bool
|
||||
text string
|
||||
color string
|
||||
wantColor bool
|
||||
}{
|
||||
{
|
||||
name: "without colors",
|
||||
useColors: false,
|
||||
text: "Test",
|
||||
color: colorRed,
|
||||
wantColor: false,
|
||||
},
|
||||
{
|
||||
name: "with colors",
|
||||
useColors: true,
|
||||
text: "Test",
|
||||
color: colorRed,
|
||||
wantColor: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||
result := formatter.colorize(tt.text, tt.color)
|
||||
|
||||
hasColor := strings.Contains(result, tt.color)
|
||||
if hasColor != tt.wantColor {
|
||||
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, tt.text) {
|
||||
t.Errorf("colorize() doesn't contain original text %q", tt.text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatContext(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: false}
|
||||
|
||||
context := map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "users",
|
||||
"column": "id",
|
||||
"pattern": "^id_",
|
||||
"max_length": 64,
|
||||
}
|
||||
|
||||
result := formatter.formatContext(context)
|
||||
|
||||
// Should not include schema, table, column (they're in location)
|
||||
if strings.Contains(result, "schema") {
|
||||
t.Error("formatContext() should skip schema field")
|
||||
}
|
||||
|
||||
if strings.Contains(result, "table=") {
|
||||
t.Error("formatContext() should skip table field")
|
||||
}
|
||||
|
||||
if strings.Contains(result, "column=") {
|
||||
t.Error("formatContext() should skip column field")
|
||||
}
|
||||
|
||||
// Should include other fields
|
||||
if !strings.Contains(result, "pattern") {
|
||||
t.Error("formatContext() should include pattern field")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "max_length") {
|
||||
t.Error("formatContext() should include max_length field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: false}
|
||||
|
||||
violation := ValidationResult{
|
||||
RuleName: "test_rule",
|
||||
Level: "error",
|
||||
Message: "Test violation message",
|
||||
Location: "public.users.id",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"pattern": "^id_",
|
||||
},
|
||||
}
|
||||
|
||||
result := formatter.formatViolation(violation, colorRed)
|
||||
|
||||
if !strings.Contains(result, "test_rule") {
|
||||
t.Error("formatViolation() should include rule name")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Test violation message") {
|
||||
t.Error("formatViolation() should include message")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "public.users.id") {
|
||||
t.Error("formatViolation() should include location")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Location:") {
|
||||
t.Error("formatViolation() should include Location label")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Message:") {
|
||||
t.Error("formatViolation() should include Message label")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportFormatConstants(t *testing.T) {
|
||||
// Test that color constants are defined
|
||||
if colorReset == "" {
|
||||
t.Error("colorReset is not defined")
|
||||
}
|
||||
|
||||
if colorRed == "" {
|
||||
t.Error("colorRed is not defined")
|
||||
}
|
||||
|
||||
if colorYellow == "" {
|
||||
t.Error("colorYellow is not defined")
|
||||
}
|
||||
|
||||
if colorGreen == "" {
|
||||
t.Error("colorGreen is not defined")
|
||||
}
|
||||
|
||||
if colorBold == "" {
|
||||
t.Error("colorBold is not defined")
|
||||
}
|
||||
}
|
||||
249
pkg/inspector/rules_test.go
Normal file
249
pkg/inspector/rules_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDefaultConfig(t *testing.T) {
|
||||
config := GetDefaultConfig()
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("GetDefaultConfig() returned nil")
|
||||
}
|
||||
|
||||
if config.Version != "1.0" {
|
||||
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
|
||||
}
|
||||
|
||||
if len(config.Rules) == 0 {
|
||||
t.Error("GetDefaultConfig() returned no rules")
|
||||
}
|
||||
|
||||
// Check that all expected rules are present
|
||||
expectedRules := []string{
|
||||
"primary_key_naming",
|
||||
"primary_key_datatype",
|
||||
"primary_key_auto_increment",
|
||||
"foreign_key_column_naming",
|
||||
"foreign_key_constraint_naming",
|
||||
"foreign_key_index",
|
||||
"table_naming_case",
|
||||
"column_naming_case",
|
||||
"table_name_length",
|
||||
"column_name_length",
|
||||
"reserved_keywords",
|
||||
"missing_primary_key",
|
||||
"orphaned_foreign_key",
|
||||
"circular_dependency",
|
||||
}
|
||||
|
||||
for _, ruleName := range expectedRules {
|
||||
if _, exists := config.Rules[ruleName]; !exists {
|
||||
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_NonExistentFile(t *testing.T) {
|
||||
// Try to load a non-existent file
|
||||
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
|
||||
}
|
||||
|
||||
// Should return default config
|
||||
if config == nil {
|
||||
t.Fatal("LoadConfig() returned nil config for non-existent file")
|
||||
}
|
||||
|
||||
if len(config.Rules) == 0 {
|
||||
t.Error("LoadConfig() returned config with no rules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ValidFile(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "test-config.yaml")
|
||||
|
||||
configContent := `version: "1.0"
|
||||
rules:
|
||||
primary_key_naming:
|
||||
enabled: "enforce"
|
||||
function: "primary_key_naming"
|
||||
pattern: "^pk_"
|
||||
message: "Primary keys must start with pk_"
|
||||
table_name_length:
|
||||
enabled: "warn"
|
||||
function: "table_name_length"
|
||||
max_length: 50
|
||||
message: "Table name too long"
|
||||
`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
config, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() returned error: %v", err)
|
||||
}
|
||||
|
||||
if config.Version != "1.0" {
|
||||
t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
|
||||
}
|
||||
|
||||
if len(config.Rules) != 2 {
|
||||
t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
|
||||
}
|
||||
|
||||
// Check primary_key_naming rule
|
||||
pkRule, exists := config.Rules["primary_key_naming"]
|
||||
if !exists {
|
||||
t.Fatal("LoadConfig() missing primary_key_naming rule")
|
||||
}
|
||||
|
||||
if pkRule.Enabled != "enforce" {
|
||||
t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
|
||||
}
|
||||
|
||||
if pkRule.Pattern != "^pk_" {
|
||||
t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
|
||||
}
|
||||
|
||||
// Check table_name_length rule
|
||||
lengthRule, exists := config.Rules["table_name_length"]
|
||||
if !exists {
|
||||
t.Fatal("LoadConfig() missing table_name_length rule")
|
||||
}
|
||||
|
||||
if lengthRule.MaxLength != 50 {
|
||||
t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_InvalidYAML(t *testing.T) {
|
||||
// Create a temporary invalid config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
|
||||
|
||||
invalidContent := `invalid: yaml: content: {[}]`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = LoadConfig(configPath)
|
||||
if err == nil {
|
||||
t.Error("LoadConfig() with invalid YAML did not return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleIsEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "enforce is enabled",
|
||||
rule: Rule{Enabled: "enforce"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "warn is enabled",
|
||||
rule: Rule{Enabled: "warn"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "off is not enabled",
|
||||
rule: Rule{Enabled: "off"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty is not enabled",
|
||||
rule: Rule{Enabled: ""},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.rule.IsEnabled(); got != tt.want {
|
||||
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleIsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "enforce is enforced",
|
||||
rule: Rule{Enabled: "enforce"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "warn is not enforced",
|
||||
rule: Rule{Enabled: "warn"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "off is not enforced",
|
||||
rule: Rule{Enabled: "off"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.rule.IsEnforced(); got != tt.want {
|
||||
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfigRuleSettings(t *testing.T) {
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Test specific rule settings
|
||||
pkNamingRule := config.Rules["primary_key_naming"]
|
||||
if pkNamingRule.Function != "primary_key_naming" {
|
||||
t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
|
||||
}
|
||||
|
||||
if pkNamingRule.Pattern != "^id_" {
|
||||
t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
|
||||
}
|
||||
|
||||
// Test datatype rule
|
||||
pkDatatypeRule := config.Rules["primary_key_datatype"]
|
||||
if len(pkDatatypeRule.AllowedTypes) == 0 {
|
||||
t.Error("primary_key_datatype has no allowed types")
|
||||
}
|
||||
|
||||
// Test length rule
|
||||
tableLengthRule := config.Rules["table_name_length"]
|
||||
if tableLengthRule.MaxLength != 64 {
|
||||
t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
|
||||
}
|
||||
|
||||
// Test reserved keywords rule
|
||||
reservedRule := config.Rules["reserved_keywords"]
|
||||
if !reservedRule.CheckTables {
|
||||
t.Error("reserved_keywords.CheckTables should be true")
|
||||
}
|
||||
if !reservedRule.CheckColumns {
|
||||
t.Error("reserved_keywords.CheckColumns should be true")
|
||||
}
|
||||
}
|
||||
837
pkg/inspector/validators_test.go
Normal file
837
pkg/inspector/validators_test.go
Normal file
@@ -0,0 +1,837 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// Helper function to create test database
|
||||
func createTestDatabase() *models.Database {
|
||||
return &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigserial",
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
},
|
||||
"username": {
|
||||
Name: "username",
|
||||
Type: "varchar(50)",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
"rid_organization": {
|
||||
Name: "rid_organization",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_users_organization": {
|
||||
Name: "fk_users_organization",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_organization"},
|
||||
ReferencedTable: "organizations",
|
||||
ReferencedSchema: "public",
|
||||
ReferencedColumns: []string{"id"},
|
||||
},
|
||||
},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_rid_organization": {
|
||||
Name: "idx_rid_organization",
|
||||
Columns: []string{"rid_organization"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "organizations",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigserial",
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
},
|
||||
"name": {
|
||||
Name: "name",
|
||||
Type: "varchar(100)",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern id",
|
||||
rule: Rule{
|
||||
Pattern: "^id$",
|
||||
Message: "Primary key should be 'id'",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern id_",
|
||||
rule: Rule{
|
||||
Pattern: "^id_",
|
||||
Message: "Primary key should start with 'id_'",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyDatatype(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "allowed type bigserial",
|
||||
rule: Rule{
|
||||
AllowedTypes: []string{"bigserial", "bigint", "int"},
|
||||
Message: "Primary key should use integer types",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "disallowed type",
|
||||
rule: Rule{
|
||||
AllowedTypes: []string{"uuid"},
|
||||
Message: "Primary key should use UUID",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "require auto increment",
|
||||
rule: Rule{
|
||||
RequireAutoIncrement: true,
|
||||
Message: "Primary key should have auto-increment",
|
||||
},
|
||||
wantLen: 0, // No violations - all PKs have auto-increment
|
||||
},
|
||||
{
|
||||
name: "disallow auto increment",
|
||||
rule: Rule{
|
||||
RequireAutoIncrement: false,
|
||||
Message: "Primary key should not have auto-increment",
|
||||
},
|
||||
wantLen: 2, // 2 violations - both PKs have auto-increment
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyColumnNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern rid_",
|
||||
rule: Rule{
|
||||
Pattern: "^rid_",
|
||||
Message: "Foreign key columns should start with 'rid_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern fk_",
|
||||
rule: Rule{
|
||||
Pattern: "^fk_",
|
||||
Message: "Foreign key columns should start with 'fk_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern fk_",
|
||||
rule: Rule{
|
||||
Pattern: "^fk_",
|
||||
Message: "Foreign key constraints should start with 'fk_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern FK_",
|
||||
rule: Rule{
|
||||
Pattern: "^FK_",
|
||||
Message: "Foreign key constraints should start with 'FK_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyIndex(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "require index with index present",
|
||||
rule: Rule{
|
||||
RequireIndex: true,
|
||||
Message: "Foreign key columns should have indexes",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "no requirement",
|
||||
rule: Rule{
|
||||
RequireIndex: false,
|
||||
Message: "Foreign key index check disabled",
|
||||
},
|
||||
wantLen: 0,
|
||||
wantPass: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyIndex(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTableNamingCase(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "lowercase snake_case pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-z0-9_]*$",
|
||||
Case: "lowercase",
|
||||
Message: "Table names should be lowercase snake_case",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[A-Z][A-Z0-9_]*$",
|
||||
Case: "uppercase",
|
||||
Message: "Table names should be uppercase",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateTableNamingCase(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnNamingCase(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "lowercase snake_case pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-z0-9_]*$",
|
||||
Case: "lowercase",
|
||||
Message: "Column names should be lowercase snake_case",
|
||||
},
|
||||
wantLen: 5, // 5 total columns across both tables
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "camelCase pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-zA-Z0-9]*$",
|
||||
Case: "camelCase",
|
||||
Message: "Column names should be camelCase",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: false, // rid_organization has underscore
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateColumnNamingCase(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTableNameLength(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "max length 64",
|
||||
rule: Rule{
|
||||
MaxLength: 64,
|
||||
Message: "Table name too long",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "max length 5",
|
||||
rule: Rule{
|
||||
MaxLength: 5,
|
||||
Message: "Table name too long",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateTableNameLength(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnNameLength(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "max length 64",
|
||||
rule: Rule{
|
||||
MaxLength: 64,
|
||||
Message: "Column name too long",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "max length 5",
|
||||
rule: Rule{
|
||||
MaxLength: 5,
|
||||
Message: "Column name too long",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: false, // Some columns exceed 5 chars
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateColumnNameLength(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReservedKeywords(t *testing.T) {
|
||||
// Create a database with reserved keywords
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "user", // "user" is a reserved keyword
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
"select": { // "select" is a reserved keyword
|
||||
Name: "select",
|
||||
Type: "varchar(50)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
checkPasses bool
|
||||
}{
|
||||
{
|
||||
name: "check tables only",
|
||||
rule: Rule{
|
||||
CheckTables: true,
|
||||
CheckColumns: false,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 1, // "user" table
|
||||
checkPasses: false,
|
||||
},
|
||||
{
|
||||
name: "check columns only",
|
||||
rule: Rule{
|
||||
CheckTables: false,
|
||||
CheckColumns: true,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 2, // "id", "select" columns (id passes, select fails)
|
||||
checkPasses: false,
|
||||
},
|
||||
{
|
||||
name: "check both",
|
||||
rule: Rule{
|
||||
CheckTables: true,
|
||||
CheckColumns: true,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 3, // "user" table + "id", "select" columns
|
||||
checkPasses: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateReservedKeywords(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMissingPrimaryKey(t *testing.T) {
|
||||
// Create database with and without primary keys
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "with_pk",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "without_pk",
|
||||
Columns: map[string]*models.Column{
|
||||
"name": {
|
||||
Name: "name",
|
||||
Type: "varchar(50)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Table missing primary key",
|
||||
}
|
||||
|
||||
results := validateMissingPrimaryKey(db, rule, "test_rule")
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
|
||||
}
|
||||
|
||||
// First result should pass (with_pk has PK)
|
||||
if results[0].Passed != true {
|
||||
t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
|
||||
}
|
||||
|
||||
// Second result should fail (without_pk missing PK)
|
||||
if results[1].Passed != false {
|
||||
t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOrphanedForeignKey(t *testing.T) {
|
||||
// Create database with orphaned FK
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_nonexistent": {
|
||||
Name: "fk_nonexistent",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_organization"},
|
||||
ReferencedTable: "nonexistent_table",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Foreign key references non-existent table",
|
||||
}
|
||||
|
||||
results := validateOrphanedForeignKey(db, rule, "test_rule")
|
||||
|
||||
if len(results) != 1 {
|
||||
t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
|
||||
}
|
||||
|
||||
if results[0].Passed != false {
|
||||
t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCircularDependency(t *testing.T) {
|
||||
// Create database with circular dependency
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "table_a",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_to_b": {
|
||||
Name: "fk_to_b",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
ReferencedTable: "table_b",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "table_b",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_to_a": {
|
||||
Name: "fk_to_a",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
ReferencedTable: "table_a",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Circular dependency detected",
|
||||
}
|
||||
|
||||
results := validateCircularDependency(db, rule, "test_rule")
|
||||
|
||||
// Should detect circular dependency in both tables
|
||||
if len(results) == 0 {
|
||||
t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
if result.Passed {
|
||||
t.Error("validateCircularDependency() passed=true, want false for circular dependency")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeDataType(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"varchar(50)", "varchar"},
|
||||
{"decimal(10,2)", "decimal"},
|
||||
{"int", "int"},
|
||||
{"BIGINT", "bigint"},
|
||||
{"VARCHAR(255)", "varchar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := normalizeDataType(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
value string
|
||||
expected bool
|
||||
}{
|
||||
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
|
||||
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
|
||||
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
|
||||
{"empty slice", []string{}, "foo", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.value)
|
||||
if result != tt.expected {
|
||||
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCycle(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
graph map[string][]string
|
||||
node string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "simple cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"B"},
|
||||
"B": {"C"},
|
||||
"C": {"A"},
|
||||
},
|
||||
node: "A",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"B"},
|
||||
"B": {"C"},
|
||||
"C": {},
|
||||
},
|
||||
node: "A",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "self cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"A"},
|
||||
},
|
||||
node: "A",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
visited := make(map[string]bool)
|
||||
recStack := make(map[string]bool)
|
||||
result := hasCycle(tt.node, tt.graph, visited, recStack)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatLocation(t *testing.T) {
|
||||
tests := []struct {
|
||||
schema string
|
||||
table string
|
||||
column string
|
||||
expected string
|
||||
}{
|
||||
{"public", "users", "id", "public.users.id"},
|
||||
{"public", "users", "", "public.users"},
|
||||
{"public", "", "", "public"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
result := formatLocation(tt.schema, tt.table, tt.column)
|
||||
if result != tt.expected {
|
||||
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
|
||||
tt.schema, tt.table, tt.column, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,14 +12,16 @@ import (
|
||||
|
||||
// MergeResult represents the result of a merge operation
|
||||
type MergeResult struct {
|
||||
SchemasAdded int
|
||||
TablesAdded int
|
||||
ColumnsAdded int
|
||||
RelationsAdded int
|
||||
DomainsAdded int
|
||||
EnumsAdded int
|
||||
ViewsAdded int
|
||||
SequencesAdded int
|
||||
SchemasAdded int
|
||||
TablesAdded int
|
||||
ColumnsAdded int
|
||||
ConstraintsAdded int
|
||||
IndexesAdded int
|
||||
RelationsAdded int
|
||||
DomainsAdded int
|
||||
EnumsAdded int
|
||||
ViewsAdded int
|
||||
SequencesAdded int
|
||||
}
|
||||
|
||||
// MergeOptions contains options for merge operations
|
||||
@@ -120,8 +122,10 @@ func (r *MergeResult) mergeTables(schema *models.Schema, source *models.Schema,
|
||||
}
|
||||
|
||||
if tgtTable, exists := existingTables[tableName]; exists {
|
||||
// Table exists, merge its columns
|
||||
// Table exists, merge its columns, constraints, and indexes
|
||||
r.mergeColumns(tgtTable, srcTable)
|
||||
r.mergeConstraints(tgtTable, srcTable)
|
||||
r.mergeIndexes(tgtTable, srcTable)
|
||||
} else {
|
||||
// Table doesn't exist, add it
|
||||
newTable := cloneTable(srcTable)
|
||||
@@ -151,6 +155,52 @@ func (r *MergeResult) mergeColumns(table *models.Table, srcTable *models.Table)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Table) {
|
||||
// Initialize constraints map if nil
|
||||
if table.Constraints == nil {
|
||||
table.Constraints = make(map[string]*models.Constraint)
|
||||
}
|
||||
|
||||
// Create map of existing constraints
|
||||
existingConstraints := make(map[string]*models.Constraint)
|
||||
for constName := range table.Constraints {
|
||||
existingConstraints[constName] = table.Constraints[constName]
|
||||
}
|
||||
|
||||
// Merge constraints
|
||||
for constName, srcConst := range srcTable.Constraints {
|
||||
if _, exists := existingConstraints[constName]; !exists {
|
||||
// Constraint doesn't exist, add it
|
||||
newConst := cloneConstraint(srcConst)
|
||||
table.Constraints[constName] = newConst
|
||||
r.ConstraintsAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table) {
|
||||
// Initialize indexes map if nil
|
||||
if table.Indexes == nil {
|
||||
table.Indexes = make(map[string]*models.Index)
|
||||
}
|
||||
|
||||
// Create map of existing indexes
|
||||
existingIndexes := make(map[string]*models.Index)
|
||||
for idxName := range table.Indexes {
|
||||
existingIndexes[idxName] = table.Indexes[idxName]
|
||||
}
|
||||
|
||||
// Merge indexes
|
||||
for idxName, srcIdx := range srcTable.Indexes {
|
||||
if _, exists := existingIndexes[idxName]; !exists {
|
||||
// Index doesn't exist, add it
|
||||
newIdx := cloneIndex(srcIdx)
|
||||
table.Indexes[idxName] = newIdx
|
||||
r.IndexesAdded++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergeResult) mergeViews(schema *models.Schema, source *models.Schema) {
|
||||
// Create map of existing views
|
||||
existingViews := make(map[string]*models.View)
|
||||
@@ -552,6 +602,8 @@ func GetMergeSummary(result *MergeResult) string {
|
||||
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
|
||||
fmt.Sprintf("Tables added: %d", result.TablesAdded),
|
||||
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
|
||||
fmt.Sprintf("Constraints added: %d", result.ConstraintsAdded),
|
||||
fmt.Sprintf("Indexes added: %d", result.IndexesAdded),
|
||||
fmt.Sprintf("Views added: %d", result.ViewsAdded),
|
||||
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
|
||||
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
|
||||
@@ -560,6 +612,7 @@ func GetMergeSummary(result *MergeResult) string {
|
||||
}
|
||||
|
||||
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
|
||||
result.ConstraintsAdded + result.IndexesAdded +
|
||||
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
|
||||
result.RelationsAdded + result.DomainsAdded
|
||||
|
||||
|
||||
617
pkg/merge/merge_test.go
Normal file
617
pkg/merge/merge_test.go
Normal file
@@ -0,0 +1,617 @@
|
||||
package merge
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
func TestMergeDatabases_NilInputs(t *testing.T) {
|
||||
result := MergeDatabases(nil, nil, nil)
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
if result.SchemasAdded != 0 {
|
||||
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeDatabases_NewSchema(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "auth"},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.SchemasAdded != 1 {
|
||||
t.Errorf("Expected 1 schema added, got %d", result.SchemasAdded)
|
||||
}
|
||||
if len(target.Schemas) != 2 {
|
||||
t.Errorf("Expected 2 schemas in target, got %d", len(target.Schemas))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeDatabases_ExistingSchema(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{Name: "public"},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.SchemasAdded != 0 {
|
||||
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
|
||||
}
|
||||
if len(target.Schemas) != 1 {
|
||||
t.Errorf("Expected 1 schema in target, got %d", len(target.Schemas))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeTables_NewTable(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "posts",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.TablesAdded != 1 {
|
||||
t.Errorf("Expected 1 table added, got %d", result.TablesAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables) != 2 {
|
||||
t.Errorf("Expected 2 tables in target schema, got %d", len(target.Schemas[0].Tables))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeColumns_NewColumn(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "int"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"email": {Name: "email", Type: "varchar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.ColumnsAdded != 1 {
|
||||
t.Errorf("Expected 1 column added, got %d", result.ColumnsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns in target table, got %d", len(target.Schemas[0].Tables[0].Columns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeConstraints_NewConstraint(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Constraints: map[string]*models.Constraint{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"ukey_users_email": {
|
||||
Type: models.UniqueConstraint,
|
||||
Columns: []string{"email"},
|
||||
Name: "ukey_users_email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.ConstraintsAdded != 1 {
|
||||
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
|
||||
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeConstraints_NilConstraintsMap(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Constraints: nil, // Nil map
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"ukey_users_email": {
|
||||
Type: models.UniqueConstraint,
|
||||
Columns: []string{"email"},
|
||||
Name: "ukey_users_email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.ConstraintsAdded != 1 {
|
||||
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
|
||||
}
|
||||
if target.Schemas[0].Tables[0].Constraints == nil {
|
||||
t.Error("Expected constraints map to be initialized")
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
|
||||
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeIndexes_NewIndex(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Indexes: map[string]*models.Index{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_users_email": {
|
||||
Name: "idx_users_email",
|
||||
Columns: []string{"email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.IndexesAdded != 1 {
|
||||
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
|
||||
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeIndexes_NilIndexesMap(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Indexes: nil, // Nil map
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_users_email": {
|
||||
Name: "idx_users_email",
|
||||
Columns: []string{"email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.IndexesAdded != 1 {
|
||||
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||
}
|
||||
if target.Schemas[0].Tables[0].Indexes == nil {
|
||||
t.Error("Expected indexes map to be initialized")
|
||||
}
|
||||
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
|
||||
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeOptions_SkipTableNames(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "migrations",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
opts := &MergeOptions{
|
||||
SkipTableNames: map[string]bool{
|
||||
"migrations": true,
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, opts)
|
||||
if result.TablesAdded != 0 {
|
||||
t.Errorf("Expected 0 tables added (skipped), got %d", result.TablesAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Tables) != 1 {
|
||||
t.Errorf("Expected 1 table in target schema, got %d", len(target.Schemas[0].Tables))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeViews_NewView(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Views: []*models.View{},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Views: []*models.View{
|
||||
{
|
||||
Name: "user_summary",
|
||||
Schema: "public",
|
||||
Definition: "SELECT * FROM users",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.ViewsAdded != 1 {
|
||||
t.Errorf("Expected 1 view added, got %d", result.ViewsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Views) != 1 {
|
||||
t.Errorf("Expected 1 view in target schema, got %d", len(target.Schemas[0].Views))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeEnums_NewEnum(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Enums: []*models.Enum{},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Enums: []*models.Enum{
|
||||
{
|
||||
Name: "user_role",
|
||||
Schema: "public",
|
||||
Values: []string{"admin", "user"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.EnumsAdded != 1 {
|
||||
t.Errorf("Expected 1 enum added, got %d", result.EnumsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Enums) != 1 {
|
||||
t.Errorf("Expected 1 enum in target schema, got %d", len(target.Schemas[0].Enums))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeDomains_NewDomain(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Domains: []*models.Domain{},
|
||||
}
|
||||
source := &models.Database{
|
||||
Domains: []*models.Domain{
|
||||
{
|
||||
Name: "auth",
|
||||
Description: "Authentication domain",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.DomainsAdded != 1 {
|
||||
t.Errorf("Expected 1 domain added, got %d", result.DomainsAdded)
|
||||
}
|
||||
if len(target.Domains) != 1 {
|
||||
t.Errorf("Expected 1 domain in target, got %d", len(target.Domains))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeRelations_NewRelation(t *testing.T) {
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Relations: []*models.Relationship{},
|
||||
},
|
||||
},
|
||||
}
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Relations: []*models.Relationship{
|
||||
{
|
||||
Name: "fk_posts_user",
|
||||
Type: models.OneToMany,
|
||||
FromTable: "posts",
|
||||
FromColumns: []string{"user_id"},
|
||||
ToTable: "users",
|
||||
ToColumns: []string{"id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
if result.RelationsAdded != 1 {
|
||||
t.Errorf("Expected 1 relation added, got %d", result.RelationsAdded)
|
||||
}
|
||||
if len(target.Schemas[0].Relations) != 1 {
|
||||
t.Errorf("Expected 1 relation in target schema, got %d", len(target.Schemas[0].Relations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMergeSummary(t *testing.T) {
|
||||
result := &MergeResult{
|
||||
SchemasAdded: 1,
|
||||
TablesAdded: 2,
|
||||
ColumnsAdded: 5,
|
||||
ConstraintsAdded: 3,
|
||||
IndexesAdded: 2,
|
||||
ViewsAdded: 1,
|
||||
}
|
||||
|
||||
summary := GetMergeSummary(result)
|
||||
if summary == "" {
|
||||
t.Error("Expected non-empty summary")
|
||||
}
|
||||
if len(summary) < 50 {
|
||||
t.Errorf("Summary seems too short: %s", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMergeSummary_Nil(t *testing.T) {
|
||||
summary := GetMergeSummary(nil)
|
||||
if summary == "" {
|
||||
t.Error("Expected non-empty summary for nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexMerge(t *testing.T) {
|
||||
// Target with existing structure
|
||||
target := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "int"},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{},
|
||||
Indexes: map[string]*models.Index{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Source with new columns, constraints, and indexes
|
||||
source := &models.Database{
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Schema: "public",
|
||||
Columns: map[string]*models.Column{
|
||||
"email": {Name: "email", Type: "varchar"},
|
||||
"guid": {Name: "guid", Type: "uuid"},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"ukey_users_email": {
|
||||
Type: models.UniqueConstraint,
|
||||
Columns: []string{"email"},
|
||||
Name: "ukey_users_email",
|
||||
},
|
||||
"ukey_users_guid": {
|
||||
Type: models.UniqueConstraint,
|
||||
Columns: []string{"guid"},
|
||||
Name: "ukey_users_guid",
|
||||
},
|
||||
},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_users_email": {
|
||||
Name: "idx_users_email",
|
||||
Columns: []string{"email"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := MergeDatabases(target, source, nil)
|
||||
|
||||
// Verify counts
|
||||
if result.ColumnsAdded != 2 {
|
||||
t.Errorf("Expected 2 columns added, got %d", result.ColumnsAdded)
|
||||
}
|
||||
if result.ConstraintsAdded != 2 {
|
||||
t.Errorf("Expected 2 constraints added, got %d", result.ConstraintsAdded)
|
||||
}
|
||||
if result.IndexesAdded != 1 {
|
||||
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
|
||||
}
|
||||
|
||||
// Verify target has merged data
|
||||
table := target.Schemas[0].Tables[0]
|
||||
if len(table.Columns) != 3 {
|
||||
t.Errorf("Expected 3 columns in merged table, got %d", len(table.Columns))
|
||||
}
|
||||
if len(table.Constraints) != 2 {
|
||||
t.Errorf("Expected 2 constraints in merged table, got %d", len(table.Constraints))
|
||||
}
|
||||
if len(table.Indexes) != 1 {
|
||||
t.Errorf("Expected 1 index in merged table, got %d", len(table.Indexes))
|
||||
}
|
||||
|
||||
// Verify specific constraint
|
||||
if _, exists := table.Constraints["ukey_users_guid"]; !exists {
|
||||
t.Error("Expected ukey_users_guid constraint to exist")
|
||||
}
|
||||
}
|
||||
339
pkg/pgsql/datatypes_test.go
Normal file
339
pkg/pgsql/datatypes_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqltype string
|
||||
want bool
|
||||
}{
|
||||
// PostgreSQL types
|
||||
{"Valid PGSQL bigint", "bigint", true},
|
||||
{"Valid PGSQL integer", "integer", true},
|
||||
{"Valid PGSQL text", "text", true},
|
||||
{"Valid PGSQL boolean", "boolean", true},
|
||||
{"Valid PGSQL double precision", "double precision", true},
|
||||
{"Valid PGSQL bytea", "bytea", true},
|
||||
{"Valid PGSQL uuid", "uuid", true},
|
||||
{"Valid PGSQL jsonb", "jsonb", true},
|
||||
{"Valid PGSQL json", "json", true},
|
||||
{"Valid PGSQL timestamp", "timestamp", true},
|
||||
{"Valid PGSQL date", "date", true},
|
||||
{"Valid PGSQL time", "time", true},
|
||||
{"Valid PGSQL citext", "citext", true},
|
||||
|
||||
// Standard types
|
||||
{"Valid std double", "double", true},
|
||||
{"Valid std blob", "blob", true},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BIGINT", "BIGINT", true},
|
||||
{"Case insensitive TeXt", "TeXt", true},
|
||||
{"Case insensitive BoOlEaN", "BoOlEaN", true},
|
||||
|
||||
// Invalid types
|
||||
{"Invalid type", "invalidtype", false},
|
||||
{"Invalid type varchar", "varchar", false},
|
||||
{"Empty string", "", false},
|
||||
{"Random string", "foobar", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ValidSQLType(tt.sqltype)
|
||||
if got != tt.want {
|
||||
t.Errorf("ValidSQLType(%q) = %v, want %v", tt.sqltype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anytype string
|
||||
want string
|
||||
}{
|
||||
// Go types to PostgreSQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double precision", "float64", "double precision"},
|
||||
{"Go float32 to real", "float32", "real"},
|
||||
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||
|
||||
// SQL types remain SQL types
|
||||
{"SQL bigint", "bigint", "bigint"},
|
||||
{"SQL integer", "integer", "integer"},
|
||||
{"SQL text", "text", "text"},
|
||||
{"SQL boolean", "boolean", "boolean"},
|
||||
{"SQL uuid", "uuid", "uuid"},
|
||||
{"SQL jsonb", "jsonb", "jsonb"},
|
||||
|
||||
// Case insensitive Go types
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
{"Case insensitive STRING", "STRING", "text"},
|
||||
|
||||
// Case insensitive SQL types
|
||||
{"Case insensitive BIGINT", "BIGINT", "bigint"},
|
||||
{"Case insensitive TEXT", "TEXT", "text"},
|
||||
|
||||
// Custom types
|
||||
{"Custom sqluuid", "sqluuid", "uuid"},
|
||||
{"Custom sqljsonb", "sqljsonb", "jsonb"},
|
||||
{"Custom sqlint64", "sqlint64", "bigint"},
|
||||
|
||||
// Unknown types default to text
|
||||
{"Unknown type varchar", "varchar", "text"},
|
||||
{"Unknown type foobar", "foobar", "text"},
|
||||
{"Empty string", "", "text"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetSQLType(tt.anytype)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSQLType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anytype string
|
||||
want string
|
||||
}{
|
||||
// Go types to PostgreSQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double precision", "float64", "double precision"},
|
||||
{"Go float32 to real", "float32", "real"},
|
||||
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||
|
||||
// SQL types remain SQL types
|
||||
{"SQL bigint", "bigint", "bigint"},
|
||||
{"SQL integer", "integer", "integer"},
|
||||
{"SQL text", "text", "text"},
|
||||
{"SQL boolean", "boolean", "boolean"},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
|
||||
// Unknown types remain unchanged (difference from GetSQLType)
|
||||
{"Unknown type varchar", "varchar", "varchar"},
|
||||
{"Unknown type foobar", "foobar", "foobar"},
|
||||
{"Empty string", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ConvertSQLType(tt.anytype)
|
||||
if got != tt.want {
|
||||
t.Errorf("ConvertSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGoType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
typeName string
|
||||
want bool
|
||||
}{
|
||||
// Go basic types
|
||||
{"Go bool", "bool", true},
|
||||
{"Go int64", "int64", true},
|
||||
{"Go int", "int", true},
|
||||
{"Go int32", "int32", true},
|
||||
{"Go int16", "int16", true},
|
||||
{"Go int8", "int8", true},
|
||||
{"Go uint", "uint", true},
|
||||
{"Go uint64", "uint64", true},
|
||||
{"Go uint32", "uint32", true},
|
||||
{"Go uint16", "uint16", true},
|
||||
{"Go uint8", "uint8", true},
|
||||
{"Go float64", "float64", true},
|
||||
{"Go float32", "float32", true},
|
||||
{"Go string", "string", true},
|
||||
{"Go []byte", "[]byte", true},
|
||||
|
||||
// Go custom types
|
||||
{"Go complex64", "complex64", true},
|
||||
{"Go complex128", "complex128", true},
|
||||
{"Go uintptr", "uintptr", true},
|
||||
{"Go Pointer", "Pointer", true},
|
||||
|
||||
// Custom SQL types
|
||||
{"Custom sqluuid", "sqluuid", true},
|
||||
{"Custom sqljsonb", "sqljsonb", true},
|
||||
{"Custom sqlint64", "sqlint64", true},
|
||||
{"Custom customdate", "customdate", true},
|
||||
{"Custom customtime", "customtime", true},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", true},
|
||||
{"Case insensitive InT64", "InT64", true},
|
||||
{"Case insensitive STRING", "STRING", true},
|
||||
|
||||
// SQL types (not Go types)
|
||||
{"SQL bigint", "bigint", false},
|
||||
{"SQL integer", "integer", false},
|
||||
{"SQL text", "text", false},
|
||||
{"SQL boolean", "boolean", false},
|
||||
|
||||
// Invalid types
|
||||
{"Invalid type", "invalidtype", false},
|
||||
{"Empty string", "", false},
|
||||
{"Random string", "foobar", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsGoType(tt.typeName)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsGoType(%q) = %v, want %v", tt.typeName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStdTypeFromGo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
typeName string
|
||||
want string
|
||||
}{
|
||||
// Go types to standard SQL types
|
||||
{"Go bool to boolean", "bool", "boolean"},
|
||||
{"Go int64 to bigint", "int64", "bigint"},
|
||||
{"Go int to integer", "int", "integer"},
|
||||
{"Go string to text", "string", "text"},
|
||||
{"Go float64 to double", "float64", "double"},
|
||||
{"Go float32 to double", "float32", "double"},
|
||||
{"Go []byte to blob", "[]byte", "blob"},
|
||||
{"Go int32 to integer", "int32", "integer"},
|
||||
{"Go int16 to smallint", "int16", "smallint"},
|
||||
|
||||
// Custom types
|
||||
{"Custom sqluuid to uuid", "sqluuid", "uuid"},
|
||||
{"Custom sqljsonb to jsonb", "sqljsonb", "jsonb"},
|
||||
{"Custom sqlint64 to bigint", "sqlint64", "bigint"},
|
||||
{"Custom customdate to date", "customdate", "date"},
|
||||
|
||||
// Case insensitive
|
||||
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||
{"Case insensitive InT64", "InT64", "bigint"},
|
||||
{"Case insensitive STRING", "STRING", "text"},
|
||||
|
||||
// Non-Go types remain unchanged
|
||||
{"SQL bigint unchanged", "bigint", "bigint"},
|
||||
{"SQL integer unchanged", "integer", "integer"},
|
||||
{"Invalid type unchanged", "invalidtype", "invalidtype"},
|
||||
{"Empty string unchanged", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetStdTypeFromGo(tt.typeName)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetStdTypeFromGo(%q) = %q, want %q", tt.typeName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoToStdTypesMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"bool": "boolean",
|
||||
"int64": "bigint",
|
||||
"int": "integer",
|
||||
"string": "text",
|
||||
"float64": "double",
|
||||
"[]byte": "blob",
|
||||
}
|
||||
|
||||
for goType, expectedStd := range expectedMappings {
|
||||
if stdType, ok := GoToStdTypes[goType]; !ok {
|
||||
t.Errorf("GoToStdTypes missing entry for %q", goType)
|
||||
} else if stdType != expectedStd {
|
||||
t.Errorf("GoToStdTypes[%q] = %q, want %q", goType, stdType, expectedStd)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the map is not empty
|
||||
if len(GoToStdTypes) == 0 {
|
||||
t.Error("GoToStdTypes map is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoToPGSQLTypesMap(t *testing.T) {
|
||||
// Test that the map contains expected entries
|
||||
expectedMappings := map[string]string{
|
||||
"bool": "boolean",
|
||||
"int64": "bigint",
|
||||
"int": "integer",
|
||||
"string": "text",
|
||||
"float64": "double precision",
|
||||
"float32": "real",
|
||||
"[]byte": "bytea",
|
||||
}
|
||||
|
||||
for goType, expectedPG := range expectedMappings {
|
||||
if pgType, ok := GoToPGSQLTypes[goType]; !ok {
|
||||
t.Errorf("GoToPGSQLTypes missing entry for %q", goType)
|
||||
} else if pgType != expectedPG {
|
||||
t.Errorf("GoToPGSQLTypes[%q] = %q, want %q", goType, pgType, expectedPG)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the map is not empty
|
||||
if len(GoToPGSQLTypes) == 0 {
|
||||
t.Error("GoToPGSQLTypes map is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeConversionConsistency(t *testing.T) {
|
||||
// Test that GetSQLType and ConvertSQLType are consistent for known types
|
||||
knownGoTypes := []string{"bool", "int64", "int", "string", "float64", "[]byte"}
|
||||
|
||||
for _, goType := range knownGoTypes {
|
||||
getSQLResult := GetSQLType(goType)
|
||||
convertResult := ConvertSQLType(goType)
|
||||
|
||||
if getSQLResult != convertResult {
|
||||
t.Errorf("Inconsistent results for %q: GetSQLType=%q, ConvertSQLType=%q",
|
||||
goType, getSQLResult, convertResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLTypeVsConvertSQLTypeDifference(t *testing.T) {
|
||||
// Test that GetSQLType returns "text" for unknown types
|
||||
// while ConvertSQLType returns the input unchanged
|
||||
unknownTypes := []string{"varchar", "char", "customtype", "unknowntype"}
|
||||
|
||||
for _, unknown := range unknownTypes {
|
||||
getSQLResult := GetSQLType(unknown)
|
||||
convertResult := ConvertSQLType(unknown)
|
||||
|
||||
if getSQLResult != "text" {
|
||||
t.Errorf("GetSQLType(%q) = %q, want %q", unknown, getSQLResult, "text")
|
||||
}
|
||||
|
||||
if convertResult != unknown {
|
||||
t.Errorf("ConvertSQLType(%q) = %q, want %q", unknown, convertResult, unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
136
pkg/pgsql/keywords_test.go
Normal file
136
pkg/pgsql/keywords_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetPostgresKeywords(t *testing.T) {
|
||||
keywords := GetPostgresKeywords()
|
||||
|
||||
// Test that keywords are returned
|
||||
if len(keywords) == 0 {
|
||||
t.Fatal("Expected non-empty list of keywords")
|
||||
}
|
||||
|
||||
// Test that we get all keywords from the map
|
||||
expectedCount := len(postgresKeywords)
|
||||
if len(keywords) != expectedCount {
|
||||
t.Errorf("Expected %d keywords, got %d", expectedCount, len(keywords))
|
||||
}
|
||||
|
||||
// Test that all returned keywords exist in the map
|
||||
for _, keyword := range keywords {
|
||||
if !postgresKeywords[keyword] {
|
||||
t.Errorf("Keyword %q not found in postgresKeywords map", keyword)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that no duplicate keywords are returned
|
||||
seen := make(map[string]bool)
|
||||
for _, keyword := range keywords {
|
||||
if seen[keyword] {
|
||||
t.Errorf("Duplicate keyword found: %q", keyword)
|
||||
}
|
||||
seen[keyword] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyword string
|
||||
want bool
|
||||
}{
|
||||
{"SELECT keyword", "select", true},
|
||||
{"FROM keyword", "from", true},
|
||||
{"WHERE keyword", "where", true},
|
||||
{"TABLE keyword", "table", true},
|
||||
{"PRIMARY keyword", "primary", true},
|
||||
{"FOREIGN keyword", "foreign", true},
|
||||
{"CREATE keyword", "create", true},
|
||||
{"DROP keyword", "drop", true},
|
||||
{"ALTER keyword", "alter", true},
|
||||
{"INDEX keyword", "index", true},
|
||||
{"NOT keyword", "not", true},
|
||||
{"NULL keyword", "null", true},
|
||||
{"TRUE keyword", "true", true},
|
||||
{"FALSE keyword", "false", true},
|
||||
{"Non-keyword lowercase", "notakeyword", false},
|
||||
{"Non-keyword uppercase", "NOTAKEYWORD", false},
|
||||
{"Empty string", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := postgresKeywords[tt.keyword]
|
||||
if got != tt.want {
|
||||
t.Errorf("postgresKeywords[%q] = %v, want %v", tt.keyword, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMapContent(t *testing.T) {
|
||||
// Test that the map contains expected common keywords
|
||||
commonKeywords := []string{
|
||||
"select", "insert", "update", "delete", "create", "drop", "alter",
|
||||
"table", "index", "view", "schema", "function", "procedure",
|
||||
"primary", "foreign", "key", "constraint", "unique", "check",
|
||||
"null", "not", "and", "or", "like", "in", "between",
|
||||
"join", "inner", "left", "right", "cross", "full", "outer",
|
||||
"where", "having", "group", "order", "limit", "offset",
|
||||
"union", "intersect", "except",
|
||||
"begin", "commit", "rollback", "transaction",
|
||||
}
|
||||
|
||||
for _, keyword := range commonKeywords {
|
||||
if !postgresKeywords[keyword] {
|
||||
t.Errorf("Expected common keyword %q to be in postgresKeywords map", keyword)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresKeywordsMapSize(t *testing.T) {
|
||||
// PostgreSQL has a substantial list of reserved keywords
|
||||
// This test ensures the map has a reasonable number of entries
|
||||
minExpectedKeywords := 200 // PostgreSQL 13+ has 400+ reserved words
|
||||
|
||||
if len(postgresKeywords) < minExpectedKeywords {
|
||||
t.Errorf("Expected at least %d keywords, got %d. The map may be incomplete.",
|
||||
minExpectedKeywords, len(postgresKeywords))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPostgresKeywordsConsistency(t *testing.T) {
|
||||
// Test that calling GetPostgresKeywords multiple times returns consistent results
|
||||
keywords1 := GetPostgresKeywords()
|
||||
keywords2 := GetPostgresKeywords()
|
||||
|
||||
if len(keywords1) != len(keywords2) {
|
||||
t.Errorf("Inconsistent results: first call returned %d keywords, second call returned %d",
|
||||
len(keywords1), len(keywords2))
|
||||
}
|
||||
|
||||
// Create a map from both results to compare
|
||||
map1 := make(map[string]bool)
|
||||
map2 := make(map[string]bool)
|
||||
|
||||
for _, k := range keywords1 {
|
||||
map1[k] = true
|
||||
}
|
||||
for _, k := range keywords2 {
|
||||
map2[k] = true
|
||||
}
|
||||
|
||||
// Check that both contain the same keywords
|
||||
for k := range map1 {
|
||||
if !map2[k] {
|
||||
t.Errorf("Keyword %q present in first call but not in second", k)
|
||||
}
|
||||
}
|
||||
for k := range map2 {
|
||||
if !map1[k] {
|
||||
t.Errorf("Keyword %q present in second call but not in first", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -603,8 +603,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
||||
column.Default = strings.Trim(defaultVal, "'\"")
|
||||
} else if attr == "unique" {
|
||||
// Create a unique constraint
|
||||
// Clean table name by removing leading underscores to avoid double underscores
|
||||
cleanTableName := strings.TrimLeft(tableName, "_")
|
||||
uniqueConstraint := models.InitConstraint(
|
||||
fmt.Sprintf("uq_%s", columnName),
|
||||
fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName),
|
||||
models.UniqueConstraint,
|
||||
)
|
||||
uniqueConstraint.Schema = schemaName
|
||||
@@ -652,8 +654,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = []string{columnName}
|
||||
}
|
||||
// Generate short constraint name based on the column
|
||||
constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
|
||||
// Generate constraint name based on table and columns
|
||||
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -737,7 +739,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
||||
|
||||
// Generate name if not provided
|
||||
if index.Name == "" {
|
||||
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
|
||||
prefix := "idx"
|
||||
if index.Unique {
|
||||
prefix = "uidx"
|
||||
}
|
||||
index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_"))
|
||||
}
|
||||
|
||||
return index
|
||||
@@ -797,10 +803,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate short constraint name based on the source column
|
||||
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
||||
if len(fromColumns) > 0 {
|
||||
constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
|
||||
// Generate constraint name based on table and columns
|
||||
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_"))
|
||||
if len(fromColumns) == 0 {
|
||||
constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable)
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(
|
||||
|
||||
@@ -777,6 +777,76 @@ func TestParseFilePrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstraintNaming(t *testing.T) {
|
||||
// Test that constraints are named with proper prefixes
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
// Find users table
|
||||
var usersTable *models.Table
|
||||
var postsTable *models.Table
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name == "users" {
|
||||
usersTable = table
|
||||
} else if table.Name == "posts" {
|
||||
postsTable = table
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usersTable == nil {
|
||||
t.Fatal("Users table not found")
|
||||
}
|
||||
if postsTable == nil {
|
||||
t.Fatal("Posts table not found")
|
||||
}
|
||||
|
||||
// Test unique constraint naming: ukey_table_column
|
||||
if _, exists := usersTable.Constraints["ukey_users_email"]; !exists {
|
||||
t.Error("Expected unique constraint 'ukey_users_email' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(usersTable.Constraints))
|
||||
}
|
||||
|
||||
if _, exists := postsTable.Constraints["ukey_posts_slug"]; !exists {
|
||||
t.Error("Expected unique constraint 'ukey_posts_slug' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||
}
|
||||
|
||||
// Test foreign key naming: fk_table_column
|
||||
if _, exists := postsTable.Constraints["fk_posts_user_id"]; !exists {
|
||||
t.Error("Expected foreign key 'fk_posts_user_id' not found")
|
||||
t.Logf("Available constraints: %v", getKeys(postsTable.Constraints))
|
||||
}
|
||||
|
||||
// Test unique index naming: uidx_table_columns
|
||||
if _, exists := postsTable.Indexes["uidx_posts_slug"]; !exists {
|
||||
t.Error("Expected unique index 'uidx_posts_slug' not found")
|
||||
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||
}
|
||||
|
||||
// Test regular index naming: idx_table_columns
|
||||
if _, exists := postsTable.Indexes["idx_posts_user_id_published"]; !exists {
|
||||
t.Error("Expected index 'idx_posts_user_id_published' not found")
|
||||
t.Logf("Available indexes: %v", getKeys(postsTable.Indexes))
|
||||
}
|
||||
}
|
||||
|
||||
func getKeys[V any](m map[string]V) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func TestHasCommentedRefs(t *testing.T) {
|
||||
// Test with the actual multifile test fixtures
|
||||
tests := []struct {
|
||||
|
||||
490
pkg/reflectutil/helpers_test.go
Normal file
490
pkg/reflectutil/helpers_test.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package reflectutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
Active bool
|
||||
Nested *nestedStruct
|
||||
Private string
|
||||
}
|
||||
|
||||
type nestedStruct struct {
|
||||
Value string
|
||||
Count int
|
||||
}
|
||||
|
||||
func TestDeref(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantValid bool
|
||||
wantKind reflect.Kind
|
||||
}{
|
||||
{
|
||||
name: "non-pointer int",
|
||||
input: 42,
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "single pointer",
|
||||
input: ptrInt(42),
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "double pointer",
|
||||
input: ptrPtr(ptrInt(42)),
|
||||
wantValid: true,
|
||||
wantKind: reflect.Int,
|
||||
},
|
||||
{
|
||||
name: "nil pointer",
|
||||
input: (*int)(nil),
|
||||
wantValid: false,
|
||||
wantKind: reflect.Ptr,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
input: "test",
|
||||
wantValid: true,
|
||||
wantKind: reflect.String,
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
input: testStruct{Name: "test"},
|
||||
wantValid: true,
|
||||
wantKind: reflect.Struct,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := reflect.ValueOf(tt.input)
|
||||
got, valid := Deref(v)
|
||||
|
||||
if valid != tt.wantValid {
|
||||
t.Errorf("Deref() valid = %v, want %v", valid, tt.wantValid)
|
||||
}
|
||||
|
||||
if got.Kind() != tt.wantKind {
|
||||
t.Errorf("Deref() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDerefInterface(t *testing.T) {
|
||||
i := 42
|
||||
pi := &i
|
||||
ppi := &pi
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantKind reflect.Kind
|
||||
}{
|
||||
{"int", 42, reflect.Int},
|
||||
{"pointer to int", &i, reflect.Int},
|
||||
{"double pointer to int", ppi, reflect.Int},
|
||||
{"string", "test", reflect.String},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DerefInterface(tt.input)
|
||||
if got.Kind() != tt.wantKind {
|
||||
t.Errorf("DerefInterface() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFieldValue(t *testing.T) {
|
||||
ts := testStruct{
|
||||
Name: "John",
|
||||
Age: 30,
|
||||
Active: true,
|
||||
Nested: &nestedStruct{Value: "nested", Count: 5},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
item interface{}
|
||||
field string
|
||||
want interface{}
|
||||
}{
|
||||
{"struct field Name", ts, "Name", "John"},
|
||||
{"struct field Age", ts, "Age", 30},
|
||||
{"struct field Active", ts, "Active", true},
|
||||
{"struct non-existent field", ts, "NonExistent", nil},
|
||||
{"pointer to struct", &ts, "Name", "John"},
|
||||
{"map string key", map[string]string{"key": "value"}, "key", "value"},
|
||||
{"map int key", map[string]int{"count": 42}, "count", 42},
|
||||
{"map non-existent key", map[string]string{"key": "value"}, "missing", nil},
|
||||
{"nil pointer", (*testStruct)(nil), "Name", nil},
|
||||
{"non-struct non-map", 42, "field", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetFieldValue(tt.item, tt.field)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetFieldValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSliceOrArray(t *testing.T) {
|
||||
arr := [3]int{1, 2, 3}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want bool
|
||||
}{
|
||||
{"slice", []int{1, 2, 3}, true},
|
||||
{"array", arr, true},
|
||||
{"pointer to slice", &[]int{1, 2, 3}, true},
|
||||
{"string", "test", false},
|
||||
{"int", 42, false},
|
||||
{"map", map[string]int{}, false},
|
||||
{"nil slice", ([]int)(nil), true}, // nil slice is still Kind==Slice
|
||||
{"nil pointer", (*[]int)(nil), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsSliceOrArray(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsSliceOrArray() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want bool
|
||||
}{
|
||||
{"map[string]int", map[string]int{"a": 1}, true},
|
||||
{"map[int]string", map[int]string{1: "a"}, true},
|
||||
{"pointer to map", &map[string]int{"a": 1}, true},
|
||||
{"slice", []int{1, 2, 3}, false},
|
||||
{"string", "test", false},
|
||||
{"int", 42, false},
|
||||
{"nil map", (map[string]int)(nil), true}, // nil map is still Kind==Map
|
||||
{"nil pointer", (*map[string]int)(nil), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsMap(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsMap() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceLen(t *testing.T) {
|
||||
arr := [3]int{1, 2, 3}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int
|
||||
}{
|
||||
{"slice length 3", []int{1, 2, 3}, 3},
|
||||
{"empty slice", []int{}, 0},
|
||||
{"array length 3", arr, 3},
|
||||
{"pointer to slice", &[]int{1, 2, 3}, 3},
|
||||
{"not a slice", "test", 0},
|
||||
{"int", 42, 0},
|
||||
{"nil slice", ([]int)(nil), 0},
|
||||
{"nil pointer", (*[]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceLen(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("SliceLen() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapLen(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int
|
||||
}{
|
||||
{"map length 2", map[string]int{"a": 1, "b": 2}, 2},
|
||||
{"empty map", map[string]int{}, 0},
|
||||
{"pointer to map", &map[string]int{"a": 1}, 1},
|
||||
{"not a map", []int{1, 2, 3}, 0},
|
||||
{"string", "test", 0},
|
||||
{"nil map", (map[string]int)(nil), 0},
|
||||
{"nil pointer", (*map[string]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapLen(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("MapLen() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceToInterfaces(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want []interface{}
|
||||
}{
|
||||
{"int slice", []int{1, 2, 3}, []interface{}{1, 2, 3}},
|
||||
{"string slice", []string{"a", "b"}, []interface{}{"a", "b"}},
|
||||
{"empty slice", []int{}, []interface{}{}},
|
||||
{"pointer to slice", &[]int{1, 2}, []interface{}{1, 2}},
|
||||
{"not a slice", "test", []interface{}{}},
|
||||
{"nil slice", ([]int)(nil), []interface{}{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceToInterfaces(tt.input)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SliceToInterfaces() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want []interface{}
|
||||
}{
|
||||
{"map with keys", map[string]int{"a": 1, "b": 2}, []interface{}{"a", "b"}},
|
||||
{"empty map", map[string]int{}, []interface{}{}},
|
||||
{"not a map", []int{1, 2, 3}, []interface{}{}},
|
||||
{"nil map", (map[string]int)(nil), []interface{}{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapKeys(tt.input)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("MapKeys() length = %v, want %v", len(got), len(tt.want))
|
||||
}
|
||||
// For maps, order is not guaranteed, so just check length
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
want int // length of values
|
||||
}{
|
||||
{"map with values", map[string]int{"a": 1, "b": 2}, 2},
|
||||
{"empty map", map[string]int{}, 0},
|
||||
{"not a map", []int{1, 2, 3}, 0},
|
||||
{"nil map", (map[string]int)(nil), 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapValues(tt.input)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("MapValues() length = %v, want %v", len(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapGet(t *testing.T) {
|
||||
m := map[string]int{"a": 1, "b": 2}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
key interface{}
|
||||
want interface{}
|
||||
}{
|
||||
{"existing key", m, "a", 1},
|
||||
{"existing key b", m, "b", 2},
|
||||
{"non-existing key", m, "c", nil},
|
||||
{"pointer to map", &m, "a", 1},
|
||||
{"not a map", []int{1, 2}, 0, nil},
|
||||
{"nil map", (map[string]int)(nil), "a", nil},
|
||||
{"nil pointer", (*map[string]int)(nil), "a", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MapGet(tt.input, tt.key)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("MapGet() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceIndex(t *testing.T) {
|
||||
s := []int{10, 20, 30}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
slice interface{}
|
||||
index int
|
||||
want interface{}
|
||||
}{
|
||||
{"index 0", s, 0, 10},
|
||||
{"index 1", s, 1, 20},
|
||||
{"index 2", s, 2, 30},
|
||||
{"negative index", s, -1, nil},
|
||||
{"out of bounds", s, 5, nil},
|
||||
{"pointer to slice", &s, 1, 20},
|
||||
{"not a slice", "test", 0, nil},
|
||||
{"nil slice", ([]int)(nil), 0, nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SliceIndex(tt.slice, tt.index)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SliceIndex() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a interface{}
|
||||
b interface{}
|
||||
want int
|
||||
}{
|
||||
{"both nil", nil, nil, 0},
|
||||
{"a nil", nil, 5, -1},
|
||||
{"b nil", 5, nil, 1},
|
||||
{"equal strings", "abc", "abc", 0},
|
||||
{"a less than b strings", "abc", "xyz", -1},
|
||||
{"a greater than b strings", "xyz", "abc", 1},
|
||||
{"equal ints", 5, 5, 0},
|
||||
{"a less than b ints", 3, 7, -1},
|
||||
{"a greater than b ints", 10, 5, 1},
|
||||
{"equal floats", 3.14, 3.14, 0},
|
||||
{"a less than b floats", 2.5, 5.5, -1},
|
||||
{"a greater than b floats", 10.5, 5.5, 1},
|
||||
{"equal uints", uint(5), uint(5), 0},
|
||||
{"different types", "abc", 123, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := CompareValues(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("CompareValues(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNestedValue(t *testing.T) {
|
||||
nested := map[string]interface{}{
|
||||
"level1": map[string]interface{}{
|
||||
"level2": map[string]interface{}{
|
||||
"value": "deep",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ts := testStruct{
|
||||
Name: "John",
|
||||
Nested: &nestedStruct{
|
||||
Value: "nested value",
|
||||
Count: 42,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
path string
|
||||
want interface{}
|
||||
}{
|
||||
{"empty path", nested, "", nested},
|
||||
{"single level map", nested, "level1", nested["level1"]},
|
||||
{"nested map", nested, "level1.level2", map[string]interface{}{"value": "deep"}},
|
||||
{"deep nested map", nested, "level1.level2.value", "deep"},
|
||||
{"struct field", ts, "Name", "John"},
|
||||
{"nested struct field", ts, "Nested", ts.Nested},
|
||||
{"non-existent path", nested, "missing.path", nil},
|
||||
{"nil input", nil, "path", nil},
|
||||
{"partial missing path", nested, "level1.missing", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetNestedValue(tt.input, tt.path)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetNestedValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a interface{}
|
||||
b interface{}
|
||||
want bool
|
||||
}{
|
||||
{"equal ints", 42, 42, true},
|
||||
{"different ints", 42, 43, false},
|
||||
{"equal strings", "test", "test", true},
|
||||
{"different strings", "test", "other", false},
|
||||
{"equal slices", []int{1, 2, 3}, []int{1, 2, 3}, true},
|
||||
{"different slices", []int{1, 2, 3}, []int{1, 2, 4}, false},
|
||||
{"equal maps", map[string]int{"a": 1}, map[string]int{"a": 1}, true},
|
||||
{"different maps", map[string]int{"a": 1}, map[string]int{"a": 2}, false},
|
||||
{"both nil", nil, nil, true},
|
||||
{"one nil", nil, 42, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DeepEqual(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("DeepEqual(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func ptrInt(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
func ptrPtr(p *int) **int {
|
||||
return &p
|
||||
}
|
||||
@@ -215,3 +215,70 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
|
||||
t.Error("SQL missing DELETE handling")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMigration_NumericConstraintNames(t *testing.T) {
|
||||
// Current database (empty)
|
||||
current := models.InitDatabase("testdb")
|
||||
currentSchema := models.InitSchema("entity")
|
||||
current.Schemas = append(current.Schemas, currentSchema)
|
||||
|
||||
// Model database (with constraint starting with number)
|
||||
model := models.InitDatabase("testdb")
|
||||
modelSchema := models.InitSchema("entity")
|
||||
|
||||
// Create individual_actor_relationship table
|
||||
table := models.InitTable("individual_actor_relationship", "entity")
|
||||
idCol := models.InitColumn("id", "individual_actor_relationship", "entity")
|
||||
idCol.Type = "integer"
|
||||
idCol.IsPrimaryKey = true
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
actorIDCol := models.InitColumn("actor_id", "individual_actor_relationship", "entity")
|
||||
actorIDCol.Type = "integer"
|
||||
table.Columns["actor_id"] = actorIDCol
|
||||
|
||||
// Add constraint with name starting with number
|
||||
constraint := &models.Constraint{
|
||||
Name: "215162_fk_actor",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"actor_id"},
|
||||
ReferencedSchema: "entity",
|
||||
ReferencedTable: "actor",
|
||||
ReferencedColumns: []string{"id"},
|
||||
OnDelete: "CASCADE",
|
||||
OnUpdate: "NO ACTION",
|
||||
}
|
||||
table.Constraints["215162_fk_actor"] = constraint
|
||||
|
||||
modelSchema.Tables = append(modelSchema.Tables, table)
|
||||
model.Schemas = append(model.Schemas, modelSchema)
|
||||
|
||||
// Generate migration
|
||||
var buf bytes.Buffer
|
||||
writer, err := NewMigrationWriter(&writers.WriterOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create writer: %v", err)
|
||||
}
|
||||
writer.writer = &buf
|
||||
|
||||
err = writer.WriteMigration(model, current)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteMigration failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
t.Logf("Generated migration:\n%s", output)
|
||||
|
||||
// Verify constraint name is properly quoted
|
||||
if !strings.Contains(output, `"215162_fk_actor"`) {
|
||||
t.Error("Constraint name starting with number should be quoted")
|
||||
}
|
||||
|
||||
// Verify the SQL is syntactically correct (contains required keywords)
|
||||
if !strings.Contains(output, "ADD CONSTRAINT") {
|
||||
t.Error("Migration missing ADD CONSTRAINT")
|
||||
}
|
||||
if !strings.Contains(output, "FOREIGN KEY") {
|
||||
t.Error("Migration missing FOREIGN KEY")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ func TemplateFunctions() map[string]interface{} {
|
||||
"quote": quote,
|
||||
"escape": escape,
|
||||
"safe_identifier": safeIdentifier,
|
||||
"quote_ident": quoteIdent,
|
||||
|
||||
// Type conversion
|
||||
"goTypeToSQL": goTypeToSQL,
|
||||
@@ -122,6 +123,43 @@ func safeIdentifier(s string) string {
|
||||
return strings.ToLower(safe)
|
||||
}
|
||||
|
||||
// quoteIdent quotes a PostgreSQL identifier if necessary
|
||||
// Identifiers need quoting if they:
|
||||
// - Start with a digit
|
||||
// - Contain special characters
|
||||
// - Are reserved keywords
|
||||
// - Contain uppercase letters (to preserve case)
|
||||
func quoteIdent(s string) string {
|
||||
if s == "" {
|
||||
return `""`
|
||||
}
|
||||
|
||||
// Check if quoting is needed
|
||||
needsQuoting := unicode.IsDigit(rune(s[0]))
|
||||
|
||||
// Starts with digit
|
||||
|
||||
// Contains uppercase letters or special characters
|
||||
for _, r := range s {
|
||||
if unicode.IsUpper(r) {
|
||||
needsQuoting = true
|
||||
break
|
||||
}
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
|
||||
needsQuoting = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsQuoting {
|
||||
// Escape double quotes by doubling them
|
||||
escaped := strings.ReplaceAll(s, `"`, `""`)
|
||||
return `"` + escaped + `"`
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Type conversion functions
|
||||
|
||||
// goTypeToSQL converts Go type to PostgreSQL type
|
||||
|
||||
@@ -101,6 +101,31 @@ func TestSafeIdentifier(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteIdent(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"valid_name", "valid_name"},
|
||||
{"ValidName", `"ValidName"`},
|
||||
{"123column", `"123column"`},
|
||||
{"215162_fk_constraint", `"215162_fk_constraint"`},
|
||||
{"user-id", `"user-id"`},
|
||||
{"user@domain", `"user@domain"`},
|
||||
{`"quoted"`, `"""quoted"""`},
|
||||
{"", `""`},
|
||||
{"lowercase", "lowercase"},
|
||||
{"with_underscore", "with_underscore"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := quoteIdent(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("quoteIdent(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoTypeToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
@@ -243,7 +268,7 @@ func TestTemplateFunctions(t *testing.T) {
|
||||
// Check that all expected functions are registered
|
||||
expectedFuncs := []string{
|
||||
"upper", "lower", "snake_case", "camelCase",
|
||||
"indent", "quote", "escape", "safe_identifier",
|
||||
"indent", "quote", "escape", "safe_identifier", "quote_ident",
|
||||
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
|
||||
"first", "last", "filter", "mapFunc", "join_with",
|
||||
"join",
|
||||
|
||||
@@ -177,6 +177,72 @@ type AuditTriggerData struct {
|
||||
Events string
|
||||
}
|
||||
|
||||
// CreateUniqueConstraintData contains data for create unique constraint template
|
||||
type CreateUniqueConstraintData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
Columns string
|
||||
}
|
||||
|
||||
// CreateCheckConstraintData contains data for create check constraint template
|
||||
type CreateCheckConstraintData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
Expression string
|
||||
}
|
||||
|
||||
// CreateForeignKeyWithCheckData contains data for create foreign key with existence check template
|
||||
type CreateForeignKeyWithCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
SourceColumns string
|
||||
TargetSchema string
|
||||
TargetTable string
|
||||
TargetColumns string
|
||||
OnDelete string
|
||||
OnUpdate string
|
||||
Deferrable bool
|
||||
}
|
||||
|
||||
// SetSequenceValueData contains data for set sequence value template
|
||||
type SetSequenceValueData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
SequenceName string
|
||||
ColumnName string
|
||||
}
|
||||
|
||||
// CreateSequenceData contains data for create sequence template
|
||||
type CreateSequenceData struct {
|
||||
SchemaName string
|
||||
SequenceName string
|
||||
Increment int
|
||||
MinValue int64
|
||||
MaxValue int64
|
||||
StartValue int64
|
||||
CacheSize int
|
||||
}
|
||||
|
||||
// AddColumnWithCheckData contains data for add column with existence check template
|
||||
type AddColumnWithCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
ColumnDefinition string
|
||||
}
|
||||
|
||||
// CreatePrimaryKeyWithAutoGenCheckData contains data for primary key with auto-generated key check template
|
||||
type CreatePrimaryKeyWithAutoGenCheckData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
AutoGenNames string // Comma-separated list of names like "'name1', 'name2'"
|
||||
Columns string
|
||||
}
|
||||
|
||||
// Execute methods for each template
|
||||
|
||||
// ExecuteCreateTable executes the create table template
|
||||
@@ -319,6 +385,76 @@ func (te *TemplateExecutor) ExecuteAuditTrigger(data AuditTriggerData) (string,
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateUniqueConstraint executes the create unique constraint template
|
||||
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data CreateUniqueConstraintData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_unique_constraint.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_unique_constraint template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateCheckConstraint executes the create check constraint template
|
||||
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data CreateCheckConstraintData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_check_constraint.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_check_constraint template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateForeignKeyWithCheck executes the create foreign key with check template
|
||||
func (te *TemplateExecutor) ExecuteCreateForeignKeyWithCheck(data CreateForeignKeyWithCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_foreign_key_with_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_foreign_key_with_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteSetSequenceValue executes the set sequence value template
|
||||
func (te *TemplateExecutor) ExecuteSetSequenceValue(data SetSequenceValueData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "set_sequence_value.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute set_sequence_value template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreateSequence executes the create sequence template
|
||||
func (te *TemplateExecutor) ExecuteCreateSequence(data CreateSequenceData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_sequence.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_sequence template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteAddColumnWithCheck executes the add column with check template
|
||||
func (te *TemplateExecutor) ExecuteAddColumnWithCheck(data AddColumnWithCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "add_column_with_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute add_column_with_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ExecuteCreatePrimaryKeyWithAutoGenCheck executes the create primary key with auto-generated key check template
|
||||
func (te *TemplateExecutor) ExecuteCreatePrimaryKeyWithAutoGenCheck(data CreatePrimaryKeyWithAutoGenCheckData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "create_primary_key_with_autogen_check.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute create_primary_key_with_autogen_check template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// Helper functions to build template data from models
|
||||
|
||||
// BuildCreateTableData builds CreateTableData from a models.Table
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD COLUMN IF NOT EXISTS {{.ColumnName}} {{.ColumnType}}
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
|
||||
{{- if .Default}} DEFAULT {{.Default}}{{end}}
|
||||
{{- if .NotNull}} NOT NULL{{end}};
|
||||
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
12
pkg/writers/pgsql/templates/add_column_with_check.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND column_name = '{{.ColumnName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD COLUMN {{.ColumnDefinition}};
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,7 +1,7 @@
|
||||
{{- if .SetDefault -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
|
||||
{{- else -}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} DROP DEFAULT;
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
|
||||
{{- end -}}
|
||||
@@ -1,2 +1,2 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ALTER COLUMN {{.ColumnName}} TYPE {{.NewType}};
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};
|
||||
@@ -1 +1 @@
|
||||
COMMENT ON COLUMN {{.SchemaName}}.{{.TableName}}.{{.ColumnName}} IS '{{.Comment}}';
|
||||
COMMENT ON COLUMN {{quote_ident .SchemaName}}.{{quote_ident .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';
|
||||
@@ -1 +1 @@
|
||||
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS '{{.Comment}}';
|
||||
COMMENT ON TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} IS '{{.Comment}}';
|
||||
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_check_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,10 +1,10 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD CONSTRAINT {{.ConstraintName}}
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||
FOREIGN KEY ({{.SourceColumns}})
|
||||
REFERENCES {{.TargetSchema}}.{{.TargetTable}} ({{.TargetColumns}})
|
||||
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
|
||||
ON DELETE {{.OnDelete}}
|
||||
ON UPDATE {{.OnUpdate}}
|
||||
DEFERRABLE;
|
||||
@@ -0,0 +1,18 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}}
|
||||
FOREIGN KEY ({{.SourceColumns}})
|
||||
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
|
||||
ON DELETE {{.OnDelete}}
|
||||
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
|
||||
DEFERRABLE{{end}};
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1,2 +1,2 @@
|
||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
|
||||
ON {{.SchemaName}}.{{.TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
|
||||
ON {{quote_ident .SchemaName}}.{{quote_ident .TableName}} USING {{.IndexType}} ({{.Columns}});
|
||||
@@ -6,8 +6,8 @@ BEGIN
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
ADD CONSTRAINT {{.ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -0,0 +1,27 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
auto_pk_name text;
|
||||
BEGIN
|
||||
-- Drop auto-generated primary key if it exists
|
||||
SELECT constraint_name INTO auto_pk_name
|
||||
FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_type = 'PRIMARY KEY'
|
||||
AND constraint_name IN ({{.AutoGenNames}});
|
||||
|
||||
IF auto_pk_name IS NOT NULL THEN
|
||||
EXECUTE 'ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||
END IF;
|
||||
|
||||
-- Add named primary key if it doesn't exist
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
6
pkg/writers/pgsql/templates/create_sequence.tmpl
Normal file
@@ -0,0 +1,6 @@
|
||||
CREATE SEQUENCE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}
|
||||
INCREMENT {{.Increment}}
|
||||
MINVALUE {{.MinValue}}
|
||||
MAXVALUE {{.MaxValue}}
|
||||
START {{.StartValue}}
|
||||
CACHE {{.CacheSize}};
|
||||
@@ -1,7 +1,7 @@
|
||||
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
|
||||
CREATE TABLE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .TableName}} (
|
||||
{{- range $i, $col := .Columns}}
|
||||
{{- if $i}},{{end}}
|
||||
{{$col.Name}} {{$col.Type}}
|
||||
{{quote_ident $col.Name}} {{$col.Type}}
|
||||
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
||||
{{- if $col.NotNull}} NOT NULL{{end}}
|
||||
{{- end}}
|
||||
|
||||
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
12
pkg/writers/pgsql/templates/create_unique_constraint.tmpl
Normal file
@@ -0,0 +1,12 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE table_schema = '{{.SchemaName}}'
|
||||
AND table_name = '{{.TableName}}'
|
||||
AND constraint_name = '{{.ConstraintName}}'
|
||||
) THEN
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -1 +1 @@
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}} DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
|
||||
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
|
||||
@@ -1 +1 @@
|
||||
DROP INDEX IF EXISTS {{.SchemaName}}.{{.IndexName}} CASCADE;
|
||||
DROP INDEX IF EXISTS {{quote_ident .SchemaName}}.{{quote_ident .IndexName}} CASCADE;
|
||||
19
pkg/writers/pgsql/templates/set_sequence_value.tmpl
Normal file
19
pkg/writers/pgsql/templates/set_sequence_value.tmpl
Normal file
@@ -0,0 +1,19 @@
|
||||
DO $$
|
||||
DECLARE
|
||||
m_cnt bigint;
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM pg_class c
|
||||
INNER JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relname = '{{.SequenceName}}'
|
||||
AND n.nspname = '{{.SchemaName}}'
|
||||
AND c.relkind = 'S'
|
||||
) THEN
|
||||
SELECT COALESCE(MAX({{quote_ident .ColumnName}}), 0) + 1
|
||||
FROM {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
|
||||
INTO m_cnt;
|
||||
|
||||
PERFORM setval('{{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}'::regclass, m_cnt);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -22,6 +22,7 @@ type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
writer io.Writer
|
||||
executionReport *ExecutionReport
|
||||
executor *TemplateExecutor
|
||||
}
|
||||
|
||||
// ExecutionReport tracks the execution status of SQL statements
|
||||
@@ -57,8 +58,10 @@ type ExecutionError struct {
|
||||
|
||||
// NewWriter creates a new PostgreSQL SQL writer
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
executor, _ := NewTemplateExecutor()
|
||||
return &Writer{
|
||||
options: options,
|
||||
options: options,
|
||||
executor: executor,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,36 +218,19 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
|
||||
}
|
||||
|
||||
// Wrap in DO block to drop auto-generated PK and add our named PK
|
||||
stmt := fmt.Sprintf("DO $$\nDECLARE\n"+
|
||||
" auto_pk_name text;\n"+
|
||||
"BEGIN\n"+
|
||||
" -- Drop auto-generated primary key if it exists\n"+
|
||||
" SELECT constraint_name INTO auto_pk_name\n"+
|
||||
" FROM information_schema.table_constraints\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND constraint_type = 'PRIMARY KEY'\n"+
|
||||
" AND constraint_name IN (%s);\n"+
|
||||
"\n"+
|
||||
" IF auto_pk_name IS NOT NULL THEN\n"+
|
||||
" EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n"+
|
||||
" END IF;\n"+
|
||||
"\n"+
|
||||
" -- Add named primary key if it doesn't exist\n"+
|
||||
" IF NOT EXISTS (\n"+
|
||||
" SELECT 1 FROM information_schema.table_constraints\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND constraint_name = '%s'\n"+
|
||||
" ) THEN\n"+
|
||||
" ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s);\n"+
|
||||
" END IF;\n"+
|
||||
"END;\n$$",
|
||||
schema.Name, table.Name, formatStringList(autoGenPKNames),
|
||||
schema.SQLName(), table.SQLName(),
|
||||
schema.Name, table.Name, pkName,
|
||||
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
|
||||
// Use template to generate primary key statement
|
||||
data := CreatePrimaryKeyWithAutoGenCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: pkName,
|
||||
AutoGenNames: formatStringList(autoGenPKNames),
|
||||
Columns: strings.Join(pkColumns, ", "),
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -290,7 +276,7 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
||||
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
uniqueStr, quoteIdentifier(index.Name), schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -302,20 +288,41 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
continue
|
||||
}
|
||||
|
||||
// Wrap in DO block to check for existing constraint
|
||||
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
|
||||
" IF NOT EXISTS (\n"+
|
||||
" SELECT 1 FROM information_schema.table_constraints\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND constraint_name = '%s'\n"+
|
||||
" ) THEN\n"+
|
||||
" ALTER TABLE %s.%s ADD CONSTRAINT %s UNIQUE (%s);\n"+
|
||||
" END IF;\n"+
|
||||
"END;\n$$",
|
||||
schema.Name, table.Name, constraint.Name,
|
||||
schema.SQLName(), table.SQLName(), constraint.Name,
|
||||
strings.Join(constraint.Columns, ", "))
|
||||
// Use template to generate unique constraint statement
|
||||
data := CreateUniqueConstraintData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
Columns: strings.Join(constraint.Columns, ", "),
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreateUniqueConstraint(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate unique constraint for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5.7: Check constraints
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.CheckConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use template to generate check constraint statement
|
||||
data := CreateCheckConstraintData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
Expression: constraint.Expression,
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreateCheckConstraint(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate check constraint for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -342,23 +349,24 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
|
||||
onUpdate = "NO ACTION"
|
||||
}
|
||||
|
||||
// Wrap in DO block to check for existing constraint
|
||||
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
|
||||
" IF NOT EXISTS (\n"+
|
||||
" SELECT 1 FROM information_schema.table_constraints\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND constraint_name = '%s'\n"+
|
||||
" ) THEN\n"+
|
||||
" ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s;\n"+
|
||||
" END IF;\n"+
|
||||
"END;\n$$",
|
||||
schema.Name, table.Name, constraint.Name,
|
||||
schema.SQLName(), table.SQLName(), constraint.Name,
|
||||
strings.Join(constraint.Columns, ", "),
|
||||
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
||||
strings.Join(constraint.ReferencedColumns, ", "),
|
||||
onDelete, onUpdate)
|
||||
// Use template to generate foreign key statement
|
||||
data := CreateForeignKeyWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
SourceColumns: strings.Join(constraint.Columns, ", "),
|
||||
TargetSchema: refSchema,
|
||||
TargetTable: constraint.ReferencedTable,
|
||||
TargetColumns: strings.Join(constraint.ReferencedColumns, ", "),
|
||||
OnDelete: onDelete,
|
||||
OnUpdate: onUpdate,
|
||||
Deferrable: false,
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -406,19 +414,18 @@ func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, e
|
||||
for _, col := range columns {
|
||||
colDef := w.generateColumnDefinition(col)
|
||||
|
||||
// Generate DO block that checks if column exists before adding
|
||||
stmt := fmt.Sprintf("DO $$\nBEGIN\n"+
|
||||
" IF NOT EXISTS (\n"+
|
||||
" SELECT 1 FROM information_schema.columns\n"+
|
||||
" WHERE table_schema = '%s'\n"+
|
||||
" AND table_name = '%s'\n"+
|
||||
" AND column_name = '%s'\n"+
|
||||
" ) THEN\n"+
|
||||
" ALTER TABLE %s.%s ADD COLUMN %s;\n"+
|
||||
" END IF;\n"+
|
||||
"END;\n$$",
|
||||
schema.Name, table.Name, col.Name,
|
||||
schema.SQLName(), table.SQLName(), colDef)
|
||||
// Use template to generate add column statement
|
||||
data := AddColumnWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ColumnName: col.Name,
|
||||
ColumnDefinition: colDef,
|
||||
}
|
||||
|
||||
stmt, err := w.executor.ExecuteAddColumnWithCheck(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||
}
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
@@ -572,6 +579,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 5.7: Create check constraints (priority 190)
|
||||
if err := w.writeCheckConstraints(schema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Phase 6: Create foreign key constraints (priority 195)
|
||||
if err := w.writeForeignKeys(schema); err != nil {
|
||||
return err
|
||||
@@ -669,13 +681,23 @@ func (w *Writer) writeSequences(schema *models.Schema) error {
|
||||
}
|
||||
|
||||
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
||||
fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n",
|
||||
schema.SQLName(), seqName)
|
||||
fmt.Fprintf(w.writer, " INCREMENT 1\n")
|
||||
fmt.Fprintf(w.writer, " MINVALUE 1\n")
|
||||
fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n")
|
||||
fmt.Fprintf(w.writer, " START 1\n")
|
||||
fmt.Fprintf(w.writer, " CACHE 1;\n\n")
|
||||
|
||||
data := CreateSequenceData{
|
||||
SchemaName: schema.Name,
|
||||
SequenceName: seqName,
|
||||
Increment: 1,
|
||||
MinValue: 1,
|
||||
MaxValue: 9223372036854775807,
|
||||
StartValue: 1,
|
||||
CacheSize: 1,
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteCreateSequence(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate create sequence for %s.%s: %w", schema.Name, seqName, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -717,18 +739,19 @@ func (w *Writer) writeAddColumns(schema *models.Schema) error {
|
||||
for _, col := range columns {
|
||||
colDef := w.generateColumnDefinition(col)
|
||||
|
||||
// Generate DO block that checks if column exists before adding
|
||||
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.columns\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND column_name = '%s'\n", col.Name)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s ADD COLUMN %s;\n",
|
||||
schema.SQLName(), table.SQLName(), colDef)
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||
data := AddColumnWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ColumnName: col.Name,
|
||||
ColumnDefinition: colDef,
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteAddColumnWithCheck(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -782,37 +805,20 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
||||
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "DO $$\nDECLARE\n")
|
||||
fmt.Fprintf(w.writer, " auto_pk_name text;\nBEGIN\n")
|
||||
data := CreatePrimaryKeyWithAutoGenCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: pkName,
|
||||
AutoGenNames: formatStringList(autoGenPKNames),
|
||||
Columns: strings.Join(columnNames, ", "),
|
||||
}
|
||||
|
||||
// Check for and drop auto-generated primary keys
|
||||
fmt.Fprintf(w.writer, " -- Drop auto-generated primary key if it exists\n")
|
||||
fmt.Fprintf(w.writer, " SELECT constraint_name INTO auto_pk_name\n")
|
||||
fmt.Fprintf(w.writer, " FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_type = 'PRIMARY KEY'\n")
|
||||
fmt.Fprintf(w.writer, " AND constraint_name IN (%s);\n", formatStringList(autoGenPKNames))
|
||||
fmt.Fprintf(w.writer, "\n")
|
||||
fmt.Fprintf(w.writer, " IF auto_pk_name IS NOT NULL THEN\n")
|
||||
fmt.Fprintf(w.writer, " EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n",
|
||||
schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "\n")
|
||||
|
||||
// Add our named primary key if it doesn't exist
|
||||
fmt.Fprintf(w.writer, " -- Add named primary key if it doesn't exist\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n",
|
||||
pkName, strings.Join(columnNames, ", "))
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -924,20 +930,56 @@ func (w *Writer) writeUniqueConstraints(schema *models.Schema) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Wrap in DO block to check for existing constraint
|
||||
fmt.Fprintf(w.writer, "DO $$\n")
|
||||
fmt.Fprintf(w.writer, "BEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", constraint.Name)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s ADD CONSTRAINT %s UNIQUE (%s);\n",
|
||||
schema.SQLName(), table.SQLName(), constraint.Name, strings.Join(columnExprs, ", "))
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n")
|
||||
fmt.Fprintf(w.writer, "$$;\n\n")
|
||||
sql, err := w.executor.ExecuteCreateUniqueConstraint(CreateUniqueConstraintData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
Columns: strings.Join(columnExprs, ", "),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate unique constraint: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "%s\n\n", sql)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCheckConstraints generates ALTER TABLE statements for check constraints
|
||||
func (w *Writer) writeCheckConstraints(schema *models.Schema) error {
|
||||
fmt.Fprintf(w.writer, "-- Check constraints for schema: %s\n", schema.Name)
|
||||
|
||||
for _, table := range schema.Tables {
|
||||
// Sort constraints by name for consistent output
|
||||
constraintNames := make([]string, 0, len(table.Constraints))
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.CheckConstraint {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Skip if expression is empty
|
||||
if constraint.Expression == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
sql, err := w.executor.ExecuteCreateCheckConstraint(CreateCheckConstraintData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
Expression: constraint.Expression,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate check constraint: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w.writer, "%s\n\n", sql)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1021,24 +1063,103 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
||||
refTable = rel.ToTable
|
||||
}
|
||||
|
||||
// Use DO block to check if constraint exists before adding
|
||||
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
||||
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
||||
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", fkName)
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
|
||||
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
|
||||
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
|
||||
refSchema, refTable, strings.Join(targetColumns, ", "))
|
||||
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
|
||||
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
|
||||
fmt.Fprintf(w.writer, " DEFERRABLE;\n")
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
||||
// Use template executor to generate foreign key with existence check
|
||||
data := CreateForeignKeyWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: fkName,
|
||||
SourceColumns: strings.Join(sourceColumns, ", "),
|
||||
TargetSchema: refSchema,
|
||||
TargetTable: refTable,
|
||||
TargetColumns: strings.Join(targetColumns, ", "),
|
||||
OnDelete: onDelete,
|
||||
OnUpdate: onUpdate,
|
||||
Deferrable: true,
|
||||
}
|
||||
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
}
|
||||
|
||||
// Also process any foreign key constraints that don't have a relationship
|
||||
processedConstraints := make(map[string]bool)
|
||||
for _, rel := range table.Relationships {
|
||||
fkName := rel.ForeignKey
|
||||
if fkName == "" {
|
||||
fkName = rel.Name
|
||||
}
|
||||
if fkName != "" {
|
||||
processedConstraints[fkName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Find unprocessed foreign key constraints
|
||||
constraintNames := make([]string, 0)
|
||||
for name, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint && !processedConstraints[name] {
|
||||
constraintNames = append(constraintNames, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(constraintNames)
|
||||
|
||||
for _, name := range constraintNames {
|
||||
constraint := table.Constraints[name]
|
||||
|
||||
// Build column lists
|
||||
sourceColumns := make([]string, 0, len(constraint.Columns))
|
||||
for _, colName := range constraint.Columns {
|
||||
if col, ok := table.Columns[colName]; ok {
|
||||
sourceColumns = append(sourceColumns, col.SQLName())
|
||||
} else {
|
||||
sourceColumns = append(sourceColumns, colName)
|
||||
}
|
||||
}
|
||||
|
||||
targetColumns := make([]string, 0, len(constraint.ReferencedColumns))
|
||||
for _, colName := range constraint.ReferencedColumns {
|
||||
targetColumns = append(targetColumns, strings.ToLower(colName))
|
||||
}
|
||||
|
||||
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
onDelete := "NO ACTION"
|
||||
if constraint.OnDelete != "" {
|
||||
onDelete = strings.ToUpper(constraint.OnDelete)
|
||||
}
|
||||
|
||||
onUpdate := "NO ACTION"
|
||||
if constraint.OnUpdate != "" {
|
||||
onUpdate = strings.ToUpper(constraint.OnUpdate)
|
||||
}
|
||||
|
||||
refSchema := constraint.ReferencedSchema
|
||||
if refSchema == "" {
|
||||
refSchema = schema.Name
|
||||
}
|
||||
refTable := constraint.ReferencedTable
|
||||
|
||||
// Use template executor to generate foreign key with existence check
|
||||
data := CreateForeignKeyWithCheckData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
ConstraintName: constraint.Name,
|
||||
SourceColumns: strings.Join(sourceColumns, ", "),
|
||||
TargetSchema: refSchema,
|
||||
TargetTable: refTable,
|
||||
TargetColumns: strings.Join(targetColumns, ", "),
|
||||
OnDelete: onDelete,
|
||||
OnUpdate: onUpdate,
|
||||
Deferrable: false,
|
||||
}
|
||||
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1057,26 +1178,19 @@ func (w *Writer) writeSetSequenceValues(schema *models.Schema) error {
|
||||
|
||||
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
||||
|
||||
fmt.Fprintf(w.writer, "DO $$\n")
|
||||
fmt.Fprintf(w.writer, "DECLARE\n")
|
||||
fmt.Fprintf(w.writer, " m_cnt bigint;\n")
|
||||
fmt.Fprintf(w.writer, "BEGIN\n")
|
||||
fmt.Fprintf(w.writer, " IF EXISTS (\n")
|
||||
fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n")
|
||||
fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n")
|
||||
fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName)
|
||||
fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name)
|
||||
fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n")
|
||||
fmt.Fprintf(w.writer, " ) THEN\n")
|
||||
fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName())
|
||||
fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName())
|
||||
fmt.Fprintf(w.writer, " INTO m_cnt;\n")
|
||||
fmt.Fprintf(w.writer, " \n")
|
||||
fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n",
|
||||
schema.SQLName(), seqName)
|
||||
fmt.Fprintf(w.writer, " END IF;\n")
|
||||
fmt.Fprintf(w.writer, "END;\n")
|
||||
fmt.Fprintf(w.writer, "$$;\n\n")
|
||||
// Use template executor to generate set sequence value statement
|
||||
data := SetSequenceValueData{
|
||||
SchemaName: schema.Name,
|
||||
TableName: table.Name,
|
||||
SequenceName: seqName,
|
||||
ColumnName: pk.Name,
|
||||
}
|
||||
sql, err := w.executor.ExecuteSetSequenceValue(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate set sequence value for %s.%s: %w", schema.Name, table.Name, err)
|
||||
}
|
||||
fmt.Fprint(w.writer, sql)
|
||||
fmt.Fprint(w.writer, "\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1296,7 +1410,8 @@ func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) erro
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Executing statement %d/%d...\n", i+1, len(statements))
|
||||
stmtType := detectStatementType(stmtTrimmed)
|
||||
fmt.Fprintf(os.Stderr, "Executing statement %d/%d [%s]...\n", i+1, len(statements), stmtType)
|
||||
|
||||
_, execErr := conn.Exec(ctx, stmt)
|
||||
if execErr != nil {
|
||||
@@ -1430,3 +1545,94 @@ func truncateStatement(stmt string) string {
|
||||
func getCurrentTimestamp() string {
|
||||
return time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// detectStatementType detects the type of SQL statement for logging
|
||||
func detectStatementType(stmt string) string {
|
||||
upperStmt := strings.ToUpper(stmt)
|
||||
|
||||
// Check for DO blocks (used for conditional DDL)
|
||||
if strings.HasPrefix(upperStmt, "DO $$") || strings.HasPrefix(upperStmt, "DO $") {
|
||||
// Look inside the DO block for the actual operation
|
||||
if strings.Contains(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD CONSTRAINT") {
|
||||
if strings.Contains(upperStmt, "UNIQUE") {
|
||||
return "ADD UNIQUE CONSTRAINT"
|
||||
} else if strings.Contains(upperStmt, "FOREIGN KEY") {
|
||||
return "ADD FOREIGN KEY"
|
||||
} else if strings.Contains(upperStmt, "PRIMARY KEY") {
|
||||
return "ADD PRIMARY KEY"
|
||||
} else if strings.Contains(upperStmt, "CHECK") {
|
||||
return "ADD CHECK CONSTRAINT"
|
||||
}
|
||||
return "ADD CONSTRAINT"
|
||||
}
|
||||
if strings.Contains(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD COLUMN") {
|
||||
return "ADD COLUMN"
|
||||
}
|
||||
if strings.Contains(upperStmt, "DROP CONSTRAINT") {
|
||||
return "DROP CONSTRAINT"
|
||||
}
|
||||
return "DO BLOCK"
|
||||
}
|
||||
|
||||
// Direct DDL statements
|
||||
if strings.HasPrefix(upperStmt, "CREATE SCHEMA") {
|
||||
return "CREATE SCHEMA"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "CREATE SEQUENCE") {
|
||||
return "CREATE SEQUENCE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "CREATE TABLE") {
|
||||
return "CREATE TABLE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "CREATE INDEX") {
|
||||
return "CREATE INDEX"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "CREATE UNIQUE INDEX") {
|
||||
return "CREATE UNIQUE INDEX"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "ALTER TABLE") {
|
||||
if strings.Contains(upperStmt, "ADD CONSTRAINT") {
|
||||
if strings.Contains(upperStmt, "FOREIGN KEY") {
|
||||
return "ADD FOREIGN KEY"
|
||||
} else if strings.Contains(upperStmt, "PRIMARY KEY") {
|
||||
return "ADD PRIMARY KEY"
|
||||
} else if strings.Contains(upperStmt, "UNIQUE") {
|
||||
return "ADD UNIQUE CONSTRAINT"
|
||||
} else if strings.Contains(upperStmt, "CHECK") {
|
||||
return "ADD CHECK CONSTRAINT"
|
||||
}
|
||||
return "ADD CONSTRAINT"
|
||||
}
|
||||
if strings.Contains(upperStmt, "ADD COLUMN") {
|
||||
return "ADD COLUMN"
|
||||
}
|
||||
if strings.Contains(upperStmt, "DROP CONSTRAINT") {
|
||||
return "DROP CONSTRAINT"
|
||||
}
|
||||
if strings.Contains(upperStmt, "ALTER COLUMN") {
|
||||
return "ALTER COLUMN"
|
||||
}
|
||||
return "ALTER TABLE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "COMMENT ON TABLE") {
|
||||
return "COMMENT ON TABLE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "COMMENT ON COLUMN") {
|
||||
return "COMMENT ON COLUMN"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "DROP TABLE") {
|
||||
return "DROP TABLE"
|
||||
}
|
||||
if strings.HasPrefix(upperStmt, "DROP INDEX") {
|
||||
return "DROP INDEX"
|
||||
}
|
||||
|
||||
// Default
|
||||
return "SQL"
|
||||
}
|
||||
|
||||
// quoteIdentifier wraps an identifier in double quotes if necessary
|
||||
// This is needed for identifiers that start with numbers or contain special characters
|
||||
func quoteIdentifier(s string) string {
|
||||
return quoteIdent(s)
|
||||
}
|
||||
|
||||
@@ -234,6 +234,226 @@ func TestWriteUniqueConstraints(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteCheckConstraints(t *testing.T) {
|
||||
// Create a test database with check constraints
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create table with check constraints
|
||||
table := models.InitTable("products", "public")
|
||||
|
||||
// Add columns
|
||||
priceCol := models.InitColumn("price", "products", "public")
|
||||
priceCol.Type = "numeric(10,2)"
|
||||
table.Columns["price"] = priceCol
|
||||
|
||||
statusCol := models.InitColumn("status", "products", "public")
|
||||
statusCol.Type = "varchar(20)"
|
||||
table.Columns["status"] = statusCol
|
||||
|
||||
quantityCol := models.InitColumn("quantity", "products", "public")
|
||||
quantityCol.Type = "integer"
|
||||
table.Columns["quantity"] = quantityCol
|
||||
|
||||
// Add check constraints
|
||||
priceConstraint := &models.Constraint{
|
||||
Name: "ck_price_positive",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "price >= 0",
|
||||
}
|
||||
table.Constraints["ck_price_positive"] = priceConstraint
|
||||
|
||||
statusConstraint := &models.Constraint{
|
||||
Name: "ck_status_valid",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "status IN ('active', 'inactive', 'discontinued')",
|
||||
}
|
||||
table.Constraints["ck_status_valid"] = statusConstraint
|
||||
|
||||
quantityConstraint := &models.Constraint{
|
||||
Name: "ck_quantity_nonnegative",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "products",
|
||||
Expression: "quantity >= 0",
|
||||
}
|
||||
table.Constraints["ck_quantity_nonnegative"] = quantityConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with output to buffer
|
||||
var buf bytes.Buffer
|
||||
options := &writers.WriterOptions{}
|
||||
writer := NewWriter(options)
|
||||
writer.writer = &buf
|
||||
|
||||
// Write the database
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Print output for debugging
|
||||
t.Logf("Generated SQL:\n%s", output)
|
||||
|
||||
// Verify check constraints are present
|
||||
if !strings.Contains(output, "-- Check constraints for schema: public") {
|
||||
t.Errorf("Output missing check constraints header")
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
|
||||
t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
|
||||
t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
|
||||
t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAllConstraintTypes(t *testing.T) {
|
||||
// Create a comprehensive test with all constraint types
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create orders table
|
||||
ordersTable := models.InitTable("orders", "public")
|
||||
|
||||
// Add columns
|
||||
idCol := models.InitColumn("id", "orders", "public")
|
||||
idCol.Type = "integer"
|
||||
idCol.IsPrimaryKey = true
|
||||
ordersTable.Columns["id"] = idCol
|
||||
|
||||
userIDCol := models.InitColumn("user_id", "orders", "public")
|
||||
userIDCol.Type = "integer"
|
||||
userIDCol.NotNull = true
|
||||
ordersTable.Columns["user_id"] = userIDCol
|
||||
|
||||
orderNumberCol := models.InitColumn("order_number", "orders", "public")
|
||||
orderNumberCol.Type = "varchar(50)"
|
||||
orderNumberCol.NotNull = true
|
||||
ordersTable.Columns["order_number"] = orderNumberCol
|
||||
|
||||
totalCol := models.InitColumn("total", "orders", "public")
|
||||
totalCol.Type = "numeric(10,2)"
|
||||
ordersTable.Columns["total"] = totalCol
|
||||
|
||||
statusCol := models.InitColumn("status", "orders", "public")
|
||||
statusCol.Type = "varchar(20)"
|
||||
ordersTable.Columns["status"] = statusCol
|
||||
|
||||
// Add primary key constraint
|
||||
pkConstraint := &models.Constraint{
|
||||
Name: "pk_orders",
|
||||
Type: models.PrimaryKeyConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"id"},
|
||||
}
|
||||
ordersTable.Constraints["pk_orders"] = pkConstraint
|
||||
|
||||
// Add unique constraint
|
||||
uniqueConstraint := &models.Constraint{
|
||||
Name: "uq_order_number",
|
||||
Type: models.UniqueConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"order_number"},
|
||||
}
|
||||
ordersTable.Constraints["uq_order_number"] = uniqueConstraint
|
||||
|
||||
// Add check constraint
|
||||
checkConstraint := &models.Constraint{
|
||||
Name: "ck_total_positive",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Expression: "total > 0",
|
||||
}
|
||||
ordersTable.Constraints["ck_total_positive"] = checkConstraint
|
||||
|
||||
statusCheckConstraint := &models.Constraint{
|
||||
Name: "ck_status_valid",
|
||||
Type: models.CheckConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Expression: "status IN ('pending', 'completed', 'cancelled')",
|
||||
}
|
||||
ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint
|
||||
|
||||
// Add foreign key constraint (referencing a users table)
|
||||
fkConstraint := &models.Constraint{
|
||||
Name: "fk_orders_user",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Schema: "public",
|
||||
Table: "orders",
|
||||
Columns: []string{"user_id"},
|
||||
ReferencedSchema: "public",
|
||||
ReferencedTable: "users",
|
||||
ReferencedColumns: []string{"id"},
|
||||
OnDelete: "CASCADE",
|
||||
OnUpdate: "CASCADE",
|
||||
}
|
||||
ordersTable.Constraints["fk_orders_user"] = fkConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, ordersTable)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with output to buffer
|
||||
var buf bytes.Buffer
|
||||
options := &writers.WriterOptions{}
|
||||
writer := NewWriter(options)
|
||||
writer.writer = &buf
|
||||
|
||||
// Write the database
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Print output for debugging
|
||||
t.Logf("Generated SQL:\n%s", output)
|
||||
|
||||
// Verify all constraint types are present
|
||||
expectedConstraints := map[string]string{
|
||||
"Primary Key": "PRIMARY KEY",
|
||||
"Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
|
||||
"Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
|
||||
"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
|
||||
"Foreign Key": "FOREIGN KEY",
|
||||
}
|
||||
|
||||
for name, expected := range expectedConstraints {
|
||||
if !strings.Contains(output, expected) {
|
||||
t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify section headers
|
||||
sections := []string{
|
||||
"-- Primary keys for schema: public",
|
||||
"-- Unique constraints for schema: public",
|
||||
"-- Check constraints for schema: public",
|
||||
"-- Foreign keys for schema: public",
|
||||
}
|
||||
|
||||
for _, section := range sections {
|
||||
if !strings.Contains(output, section) {
|
||||
t.Errorf("Output missing section header: %s", section)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTable(t *testing.T) {
|
||||
// Create a single table
|
||||
table := models.InitTable("products", "public")
|
||||
|
||||
Reference in New Issue
Block a user