test(pgsql, reflectutil): ✨ add comprehensive test coverage
All checks were successful
All checks were successful
* Introduce tests for PostgreSQL data types and keywords. * Implement tests for reflect utility functions. * Ensure consistency and correctness of type conversions and keyword mappings. * Validate behavior for various edge cases and input types.
This commit is contained in:
714
pkg/commontypes/commontypes_test.go
Normal file
714
pkg/commontypes/commontypes_test.go
Normal file
@@ -0,0 +1,714 @@
|
|||||||
|
package commontypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractBaseType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"varchar with length", "varchar(100)", "varchar"},
|
||||||
|
{"VARCHAR uppercase with length", "VARCHAR(255)", "varchar"},
|
||||||
|
{"numeric with precision", "numeric(10,2)", "numeric"},
|
||||||
|
{"NUMERIC uppercase", "NUMERIC(18,4)", "numeric"},
|
||||||
|
{"decimal with precision", "decimal(15,3)", "decimal"},
|
||||||
|
{"char with length", "char(50)", "char"},
|
||||||
|
{"simple integer", "integer", "integer"},
|
||||||
|
{"simple text", "text", "text"},
|
||||||
|
{"bigint", "bigint", "bigint"},
|
||||||
|
{"With spaces", " varchar(100) ", "varchar"},
|
||||||
|
{"No parentheses", "boolean", "boolean"},
|
||||||
|
{"Empty string", "", ""},
|
||||||
|
{"Mixed case", "VarChar(100)", "varchar"},
|
||||||
|
{"timestamp with time zone", "timestamp(6) with time zone", "timestamp"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ExtractBaseType(tt.sqlType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ExtractBaseType(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeType(t *testing.T) {
|
||||||
|
// NormalizeType is an alias for ExtractBaseType, test that they behave the same
|
||||||
|
testCases := []string{
|
||||||
|
"varchar(100)",
|
||||||
|
"numeric(10,2)",
|
||||||
|
"integer",
|
||||||
|
"text",
|
||||||
|
" VARCHAR(255) ",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc, func(t *testing.T) {
|
||||||
|
extracted := ExtractBaseType(tc)
|
||||||
|
normalized := NormalizeType(tc)
|
||||||
|
if extracted != normalized {
|
||||||
|
t.Errorf("ExtractBaseType(%q) = %q, but NormalizeType(%q) = %q",
|
||||||
|
tc, extracted, tc, normalized)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToGo(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types (nullable)
|
||||||
|
{"integer nullable", "integer", true, "int32"},
|
||||||
|
{"bigint nullable", "bigint", true, "int64"},
|
||||||
|
{"smallint nullable", "smallint", true, "int16"},
|
||||||
|
{"serial nullable", "serial", true, "int32"},
|
||||||
|
|
||||||
|
// Integer types (not nullable)
|
||||||
|
{"integer not nullable", "integer", false, "*int32"},
|
||||||
|
{"bigint not nullable", "bigint", false, "*int64"},
|
||||||
|
{"smallint not nullable", "smallint", false, "*int16"},
|
||||||
|
|
||||||
|
// String types (nullable)
|
||||||
|
{"text nullable", "text", true, "string"},
|
||||||
|
{"varchar nullable", "varchar", true, "string"},
|
||||||
|
{"varchar with length nullable", "varchar(100)", true, "string"},
|
||||||
|
|
||||||
|
// String types (not nullable)
|
||||||
|
{"text not nullable", "text", false, "*string"},
|
||||||
|
{"varchar not nullable", "varchar", false, "*string"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "bool"},
|
||||||
|
{"boolean not nullable", "boolean", false, "*bool"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "float32"},
|
||||||
|
{"double precision nullable", "double precision", true, "float64"},
|
||||||
|
{"real not nullable", "real", false, "*float32"},
|
||||||
|
{"double precision not nullable", "double precision", false, "*float64"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "time.Time"},
|
||||||
|
{"date nullable", "date", true, "time.Time"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "*time.Time"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "[]byte"},
|
||||||
|
{"bytea not nullable", "bytea", false, "[]byte"}, // Slices don't get pointer
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "string"},
|
||||||
|
{"uuid not nullable", "uuid", false, "*string"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "string"},
|
||||||
|
{"jsonb nullable", "jsonb", true, "string"},
|
||||||
|
|
||||||
|
// Array
|
||||||
|
{"array nullable", "array", true, "[]string"},
|
||||||
|
{"array not nullable", "array", false, "[]string"}, // Slices don't get pointer
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "interface{}"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "interface{}"}, // Interface doesn't get pointer
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToGo(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToGo(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToTypeScript(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types
|
||||||
|
{"integer nullable", "integer", true, "number"},
|
||||||
|
{"integer not nullable", "integer", false, "number | null"},
|
||||||
|
{"bigint nullable", "bigint", true, "number"},
|
||||||
|
{"bigint not nullable", "bigint", false, "number | null"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text nullable", "text", true, "string"},
|
||||||
|
{"text not nullable", "text", false, "string | null"},
|
||||||
|
{"varchar nullable", "varchar", true, "string"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "boolean"},
|
||||||
|
{"boolean not nullable", "boolean", false, "boolean | null"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "number"},
|
||||||
|
{"double precision nullable", "double precision", true, "number"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "Date"},
|
||||||
|
{"date nullable", "date", true, "Date"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "Date | null"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "Buffer"},
|
||||||
|
{"bytea not nullable", "bytea", false, "Buffer | null"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "any"},
|
||||||
|
{"jsonb nullable", "jsonb", true, "any"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "string"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "any"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "any | null"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToTypeScript(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToTypeScript(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToPython(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types
|
||||||
|
{"integer", "integer", "int"},
|
||||||
|
{"bigint", "bigint", "int"},
|
||||||
|
{"smallint", "smallint", "int"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text", "text", "str"},
|
||||||
|
{"varchar", "varchar", "str"},
|
||||||
|
{"varchar(100)", "varchar(100)", "str"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean", "boolean", "bool"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real", "real", "float"},
|
||||||
|
{"double precision", "double precision", "float"},
|
||||||
|
{"numeric", "numeric", "Decimal"},
|
||||||
|
{"decimal", "decimal", "Decimal"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp", "timestamp", "datetime"},
|
||||||
|
{"date", "date", "date"},
|
||||||
|
{"time", "time", "time"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea", "bytea", "bytes"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json", "json", "dict"},
|
||||||
|
{"jsonb", "jsonb", "dict"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid", "uuid", "UUID"},
|
||||||
|
|
||||||
|
// Array
|
||||||
|
{"array", "array", "list"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type", "unknowntype", "Any"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToPython(tt.sqlType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToPython(%q) = %q, want %q", tt.sqlType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToCSharp(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types (nullable)
|
||||||
|
{"integer nullable", "integer", true, "int"},
|
||||||
|
{"bigint nullable", "bigint", true, "long"},
|
||||||
|
{"smallint nullable", "smallint", true, "short"},
|
||||||
|
|
||||||
|
// Integer types (not nullable - value types get ?)
|
||||||
|
{"integer not nullable", "integer", false, "int?"},
|
||||||
|
{"bigint not nullable", "bigint", false, "long?"},
|
||||||
|
{"smallint not nullable", "smallint", false, "short?"},
|
||||||
|
|
||||||
|
// String types (reference types, no ? needed)
|
||||||
|
{"text nullable", "text", true, "string"},
|
||||||
|
{"text not nullable", "text", false, "string"},
|
||||||
|
{"varchar nullable", "varchar", true, "string"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "bool"},
|
||||||
|
{"boolean not nullable", "boolean", false, "bool?"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "float"},
|
||||||
|
{"double precision nullable", "double precision", true, "double"},
|
||||||
|
{"decimal nullable", "decimal", true, "decimal"},
|
||||||
|
{"real not nullable", "real", false, "float?"},
|
||||||
|
{"double precision not nullable", "double precision", false, "double?"},
|
||||||
|
{"decimal not nullable", "decimal", false, "decimal?"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "DateTime"},
|
||||||
|
{"date nullable", "date", true, "DateTime"},
|
||||||
|
{"timestamptz nullable", "timestamptz", true, "DateTimeOffset"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "DateTime?"},
|
||||||
|
{"timestamptz not nullable", "timestamptz", false, "DateTimeOffset?"},
|
||||||
|
|
||||||
|
// Binary (array type, no ?)
|
||||||
|
{"bytea nullable", "bytea", true, "byte[]"},
|
||||||
|
{"bytea not nullable", "bytea", false, "byte[]"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "Guid"},
|
||||||
|
{"uuid not nullable", "uuid", false, "Guid?"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "string"},
|
||||||
|
|
||||||
|
// Unknown types (object is reference type)
|
||||||
|
{"unknown type nullable", "unknowntype", true, "object"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "object"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToCSharp(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToCSharp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNeedsTimeImport(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
goType string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"time.Time type", "time.Time", true},
|
||||||
|
{"pointer to time.Time", "*time.Time", true},
|
||||||
|
{"int32 type", "int32", false},
|
||||||
|
{"string type", "string", false},
|
||||||
|
{"bool type", "bool", false},
|
||||||
|
{"[]byte type", "[]byte", false},
|
||||||
|
{"interface{}", "interface{}", false},
|
||||||
|
{"empty string", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := NeedsTimeImport(tt.goType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("NeedsTimeImport(%q) = %v, want %v", tt.goType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoTypeMap(t *testing.T) {
|
||||||
|
// Test that the map contains expected entries
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "int32",
|
||||||
|
"bigint": "int64",
|
||||||
|
"text": "string",
|
||||||
|
"boolean": "bool",
|
||||||
|
"double precision": "float64",
|
||||||
|
"bytea": "[]byte",
|
||||||
|
"timestamp": "time.Time",
|
||||||
|
"uuid": "string",
|
||||||
|
"json": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedGoType := range expectedMappings {
|
||||||
|
if goType, ok := GoTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("GoTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if goType != expectedGoType {
|
||||||
|
t.Errorf("GoTypeMap[%q] = %q, want %q", sqlType, goType, expectedGoType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(GoTypeMap) == 0 {
|
||||||
|
t.Error("GoTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeScriptTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "number",
|
||||||
|
"bigint": "number",
|
||||||
|
"text": "string",
|
||||||
|
"boolean": "boolean",
|
||||||
|
"double precision": "number",
|
||||||
|
"bytea": "Buffer",
|
||||||
|
"timestamp": "Date",
|
||||||
|
"uuid": "string",
|
||||||
|
"json": "any",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedTSType := range expectedMappings {
|
||||||
|
if tsType, ok := TypeScriptTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("TypeScriptTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if tsType != expectedTSType {
|
||||||
|
t.Errorf("TypeScriptTypeMap[%q] = %q, want %q", sqlType, tsType, expectedTSType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(TypeScriptTypeMap) == 0 {
|
||||||
|
t.Error("TypeScriptTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPythonTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "int",
|
||||||
|
"bigint": "int",
|
||||||
|
"text": "str",
|
||||||
|
"boolean": "bool",
|
||||||
|
"real": "float",
|
||||||
|
"numeric": "Decimal",
|
||||||
|
"bytea": "bytes",
|
||||||
|
"date": "date",
|
||||||
|
"uuid": "UUID",
|
||||||
|
"json": "dict",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedPyType := range expectedMappings {
|
||||||
|
if pyType, ok := PythonTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("PythonTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if pyType != expectedPyType {
|
||||||
|
t.Errorf("PythonTypeMap[%q] = %q, want %q", sqlType, pyType, expectedPyType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(PythonTypeMap) == 0 {
|
||||||
|
t.Error("PythonTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSharpTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "int",
|
||||||
|
"bigint": "long",
|
||||||
|
"smallint": "short",
|
||||||
|
"text": "string",
|
||||||
|
"boolean": "bool",
|
||||||
|
"double precision": "double",
|
||||||
|
"decimal": "decimal",
|
||||||
|
"bytea": "byte[]",
|
||||||
|
"timestamp": "DateTime",
|
||||||
|
"uuid": "Guid",
|
||||||
|
"json": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedCSType := range expectedMappings {
|
||||||
|
if csType, ok := CSharpTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("CSharpTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if csType != expectedCSType {
|
||||||
|
t.Errorf("CSharpTypeMap[%q] = %q, want %q", sqlType, csType, expectedCSType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(CSharpTypeMap) == 0 {
|
||||||
|
t.Error("CSharpTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToJava(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types
|
||||||
|
{"integer nullable", "integer", true, "Integer"},
|
||||||
|
{"integer not nullable", "integer", false, "Integer"},
|
||||||
|
{"bigint nullable", "bigint", true, "Long"},
|
||||||
|
{"smallint nullable", "smallint", true, "Short"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text nullable", "text", true, "String"},
|
||||||
|
{"varchar nullable", "varchar", true, "String"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "Boolean"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "Float"},
|
||||||
|
{"double precision nullable", "double precision", true, "Double"},
|
||||||
|
{"numeric nullable", "numeric", true, "BigDecimal"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "Timestamp"},
|
||||||
|
{"date nullable", "date", true, "Date"},
|
||||||
|
{"time nullable", "time", true, "Time"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "byte[]"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "UUID"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "String"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "Object"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToJava(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToJava(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToPhp(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types (nullable)
|
||||||
|
{"integer nullable", "integer", true, "int"},
|
||||||
|
{"bigint nullable", "bigint", true, "int"},
|
||||||
|
{"smallint nullable", "smallint", true, "int"},
|
||||||
|
|
||||||
|
// Integer types (not nullable)
|
||||||
|
{"integer not nullable", "integer", false, "?int"},
|
||||||
|
{"bigint not nullable", "bigint", false, "?int"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text nullable", "text", true, "string"},
|
||||||
|
{"text not nullable", "text", false, "?string"},
|
||||||
|
{"varchar nullable", "varchar", true, "string"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "string"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "bool"},
|
||||||
|
{"boolean not nullable", "boolean", false, "?bool"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "float"},
|
||||||
|
{"double precision nullable", "double precision", true, "float"},
|
||||||
|
{"real not nullable", "real", false, "?float"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "\\DateTime"},
|
||||||
|
{"date nullable", "date", true, "\\DateTime"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "?\\DateTime"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "string"},
|
||||||
|
{"bytea not nullable", "bytea", false, "?string"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "array"},
|
||||||
|
{"json not nullable", "json", false, "?array"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "string"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "mixed"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "mixed"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToPhp(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToPhp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLToRust(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlType string
|
||||||
|
nullable bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Integer types (nullable)
|
||||||
|
{"integer nullable", "integer", true, "i32"},
|
||||||
|
{"bigint nullable", "bigint", true, "i64"},
|
||||||
|
{"smallint nullable", "smallint", true, "i16"},
|
||||||
|
|
||||||
|
// Integer types (not nullable)
|
||||||
|
{"integer not nullable", "integer", false, "Option<i32>"},
|
||||||
|
{"bigint not nullable", "bigint", false, "Option<i64>"},
|
||||||
|
{"smallint not nullable", "smallint", false, "Option<i16>"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"text nullable", "text", true, "String"},
|
||||||
|
{"text not nullable", "text", false, "Option<String>"},
|
||||||
|
{"varchar nullable", "varchar", true, "String"},
|
||||||
|
{"varchar(100) nullable", "varchar(100)", true, "String"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean nullable", "boolean", true, "bool"},
|
||||||
|
{"boolean not nullable", "boolean", false, "Option<bool>"},
|
||||||
|
|
||||||
|
// Float types
|
||||||
|
{"real nullable", "real", true, "f32"},
|
||||||
|
{"double precision nullable", "double precision", true, "f64"},
|
||||||
|
{"real not nullable", "real", false, "Option<f32>"},
|
||||||
|
{"double precision not nullable", "double precision", false, "Option<f64>"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp nullable", "timestamp", true, "NaiveDateTime"},
|
||||||
|
{"timestamptz nullable", "timestamptz", true, "DateTime<Utc>"},
|
||||||
|
{"date nullable", "date", true, "NaiveDate"},
|
||||||
|
{"time nullable", "time", true, "NaiveTime"},
|
||||||
|
{"timestamp not nullable", "timestamp", false, "Option<NaiveDateTime>"},
|
||||||
|
|
||||||
|
// Binary
|
||||||
|
{"bytea nullable", "bytea", true, "Vec<u8>"},
|
||||||
|
{"bytea not nullable", "bytea", false, "Option<Vec<u8>>"},
|
||||||
|
|
||||||
|
// JSON
|
||||||
|
{"json nullable", "json", true, "serde_json::Value"},
|
||||||
|
{"json not nullable", "json", false, "Option<serde_json::Value>"},
|
||||||
|
|
||||||
|
// UUID
|
||||||
|
{"uuid nullable", "uuid", true, "String"},
|
||||||
|
|
||||||
|
// Unknown types
|
||||||
|
{"unknown type nullable", "unknowntype", true, "String"},
|
||||||
|
{"unknown type not nullable", "unknowntype", false, "Option<String>"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SQLToRust(tt.sqlType, tt.nullable)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SQLToRust(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJavaTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "Integer",
|
||||||
|
"bigint": "Long",
|
||||||
|
"smallint": "Short",
|
||||||
|
"text": "String",
|
||||||
|
"boolean": "Boolean",
|
||||||
|
"double precision": "Double",
|
||||||
|
"numeric": "BigDecimal",
|
||||||
|
"bytea": "byte[]",
|
||||||
|
"timestamp": "Timestamp",
|
||||||
|
"uuid": "UUID",
|
||||||
|
"date": "Date",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedJavaType := range expectedMappings {
|
||||||
|
if javaType, ok := JavaTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("JavaTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if javaType != expectedJavaType {
|
||||||
|
t.Errorf("JavaTypeMap[%q] = %q, want %q", sqlType, javaType, expectedJavaType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(JavaTypeMap) == 0 {
|
||||||
|
t.Error("JavaTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPHPTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "int",
|
||||||
|
"bigint": "int",
|
||||||
|
"text": "string",
|
||||||
|
"boolean": "bool",
|
||||||
|
"double precision": "float",
|
||||||
|
"bytea": "string",
|
||||||
|
"timestamp": "\\DateTime",
|
||||||
|
"uuid": "string",
|
||||||
|
"json": "array",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedPHPType := range expectedMappings {
|
||||||
|
if phpType, ok := PHPTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("PHPTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if phpType != expectedPHPType {
|
||||||
|
t.Errorf("PHPTypeMap[%q] = %q, want %q", sqlType, phpType, expectedPHPType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(PHPTypeMap) == 0 {
|
||||||
|
t.Error("PHPTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRustTypeMap(t *testing.T) {
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"integer": "i32",
|
||||||
|
"bigint": "i64",
|
||||||
|
"smallint": "i16",
|
||||||
|
"text": "String",
|
||||||
|
"boolean": "bool",
|
||||||
|
"double precision": "f64",
|
||||||
|
"real": "f32",
|
||||||
|
"bytea": "Vec<u8>",
|
||||||
|
"timestamp": "NaiveDateTime",
|
||||||
|
"timestamptz": "DateTime<Utc>",
|
||||||
|
"date": "NaiveDate",
|
||||||
|
"json": "serde_json::Value",
|
||||||
|
}
|
||||||
|
|
||||||
|
for sqlType, expectedRustType := range expectedMappings {
|
||||||
|
if rustType, ok := RustTypeMap[sqlType]; !ok {
|
||||||
|
t.Errorf("RustTypeMap missing entry for %q", sqlType)
|
||||||
|
} else if rustType != expectedRustType {
|
||||||
|
t.Errorf("RustTypeMap[%q] = %q, want %q", sqlType, rustType, expectedRustType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(RustTypeMap) == 0 {
|
||||||
|
t.Error("RustTypeMap is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
558
pkg/diff/diff_test.go
Normal file
558
pkg/diff/diff_test.go
Normal file
@@ -0,0 +1,558 @@
|
|||||||
|
package diff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompareDatabases(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source *models.Database
|
||||||
|
target *models.Database
|
||||||
|
want func(*DiffResult) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical databases",
|
||||||
|
source: &models.Database{
|
||||||
|
Name: "source",
|
||||||
|
Schemas: []*models.Schema{},
|
||||||
|
},
|
||||||
|
target: &models.Database{
|
||||||
|
Name: "target",
|
||||||
|
Schemas: []*models.Schema{},
|
||||||
|
},
|
||||||
|
want: func(r *DiffResult) bool {
|
||||||
|
return r.Source == "source" && r.Target == "target" &&
|
||||||
|
len(r.Schemas.Missing) == 0 && len(r.Schemas.Extra) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different schemas",
|
||||||
|
source: &models.Database{
|
||||||
|
Name: "source",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
target: &models.Database{
|
||||||
|
Name: "target",
|
||||||
|
Schemas: []*models.Schema{},
|
||||||
|
},
|
||||||
|
want: func(r *DiffResult) bool {
|
||||||
|
return len(r.Schemas.Missing) == 1 && r.Schemas.Missing[0].Name == "public"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := CompareDatabases(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("CompareDatabases() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareColumns(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source map[string]*models.Column
|
||||||
|
target map[string]*models.Column
|
||||||
|
want func(*ColumnDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical columns",
|
||||||
|
source: map[string]*models.Column{},
|
||||||
|
target: map[string]*models.Column{},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing column",
|
||||||
|
source: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Column{},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "id"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra column",
|
||||||
|
source: map[string]*models.Column{},
|
||||||
|
target: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "id"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified column type",
|
||||||
|
source: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "bigint"},
|
||||||
|
},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Name == "id" &&
|
||||||
|
d.Modified[0].Changes["type"] != nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified column nullable",
|
||||||
|
source: map[string]*models.Column{
|
||||||
|
"name": {Name: "name", Type: "text", NotNull: true},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Column{
|
||||||
|
"name": {Name: "name", Type: "text", NotNull: false},
|
||||||
|
},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Changes["not_null"] != nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified column length",
|
||||||
|
source: map[string]*models.Column{
|
||||||
|
"name": {Name: "name", Type: "varchar", Length: 100},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Column{
|
||||||
|
"name": {Name: "name", Type: "varchar", Length: 255},
|
||||||
|
},
|
||||||
|
want: func(d *ColumnDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Changes["length"] != nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareColumns(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareColumns() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareColumnDetails(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source *models.Column
|
||||||
|
target *models.Column
|
||||||
|
want int // number of changes
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical columns",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer"},
|
||||||
|
target: &models.Column{Name: "id", Type: "integer"},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "type change",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer"},
|
||||||
|
target: &models.Column{Name: "id", Type: "bigint"},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "length change",
|
||||||
|
source: &models.Column{Name: "name", Type: "varchar", Length: 100},
|
||||||
|
target: &models.Column{Name: "name", Type: "varchar", Length: 255},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "precision change",
|
||||||
|
source: &models.Column{Name: "price", Type: "numeric", Precision: 10},
|
||||||
|
target: &models.Column{Name: "price", Type: "numeric", Precision: 12},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "scale change",
|
||||||
|
source: &models.Column{Name: "price", Type: "numeric", Scale: 2},
|
||||||
|
target: &models.Column{Name: "price", Type: "numeric", Scale: 4},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not null change",
|
||||||
|
source: &models.Column{Name: "name", Type: "text", NotNull: true},
|
||||||
|
target: &models.Column{Name: "name", Type: "text", NotNull: false},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auto increment change",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer", AutoIncrement: true},
|
||||||
|
target: &models.Column{Name: "id", Type: "integer", AutoIncrement: false},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "primary key change",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: true},
|
||||||
|
target: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: false},
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple changes",
|
||||||
|
source: &models.Column{Name: "id", Type: "integer", NotNull: true, AutoIncrement: true},
|
||||||
|
target: &models.Column{Name: "id", Type: "bigint", NotNull: false, AutoIncrement: false},
|
||||||
|
want: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareColumnDetails(tt.source, tt.target)
|
||||||
|
if len(got) != tt.want {
|
||||||
|
t.Errorf("compareColumnDetails() = %d changes, want %d", len(got), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareIndexes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source map[string]*models.Index
|
||||||
|
target map[string]*models.Index
|
||||||
|
want func(*IndexDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical indexes",
|
||||||
|
source: map[string]*models.Index{},
|
||||||
|
target: map[string]*models.Index{},
|
||||||
|
want: func(d *IndexDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing index",
|
||||||
|
source: map[string]*models.Index{
|
||||||
|
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Index{},
|
||||||
|
want: func(d *IndexDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "idx_name"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra index",
|
||||||
|
source: map[string]*models.Index{},
|
||||||
|
target: map[string]*models.Index{
|
||||||
|
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
|
||||||
|
},
|
||||||
|
want: func(d *IndexDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "idx_name"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified index uniqueness",
|
||||||
|
source: map[string]*models.Index{
|
||||||
|
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: false},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Index{
|
||||||
|
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: true},
|
||||||
|
},
|
||||||
|
want: func(d *IndexDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Name == "idx_name"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareIndexes(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareIndexes() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareConstraints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source map[string]*models.Constraint
|
||||||
|
target map[string]*models.Constraint
|
||||||
|
want func(*ConstraintDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical constraints",
|
||||||
|
source: map[string]*models.Constraint{},
|
||||||
|
target: map[string]*models.Constraint{},
|
||||||
|
want: func(d *ConstraintDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing constraint",
|
||||||
|
source: map[string]*models.Constraint{
|
||||||
|
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Constraint{},
|
||||||
|
want: func(d *ConstraintDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "pk_id"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra constraint",
|
||||||
|
source: map[string]*models.Constraint{},
|
||||||
|
target: map[string]*models.Constraint{
|
||||||
|
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
|
||||||
|
},
|
||||||
|
want: func(d *ConstraintDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "pk_id"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareConstraints(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareConstraints() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareRelationships(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source map[string]*models.Relationship
|
||||||
|
target map[string]*models.Relationship
|
||||||
|
want func(*RelationshipDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical relationships",
|
||||||
|
source: map[string]*models.Relationship{},
|
||||||
|
target: map[string]*models.Relationship{},
|
||||||
|
want: func(d *RelationshipDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing relationship",
|
||||||
|
source: map[string]*models.Relationship{
|
||||||
|
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||||
|
},
|
||||||
|
target: map[string]*models.Relationship{},
|
||||||
|
want: func(d *RelationshipDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "fk_user"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra relationship",
|
||||||
|
source: map[string]*models.Relationship{},
|
||||||
|
target: map[string]*models.Relationship{
|
||||||
|
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
|
||||||
|
},
|
||||||
|
want: func(d *RelationshipDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "fk_user"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareRelationships(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareRelationships() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareTables(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source []*models.Table
|
||||||
|
target []*models.Table
|
||||||
|
want func(*TableDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical tables",
|
||||||
|
source: []*models.Table{},
|
||||||
|
target: []*models.Table{},
|
||||||
|
want: func(d *TableDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing table",
|
||||||
|
source: []*models.Table{
|
||||||
|
{Name: "users", Schema: "public"},
|
||||||
|
},
|
||||||
|
target: []*models.Table{},
|
||||||
|
want: func(d *TableDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "users"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra table",
|
||||||
|
source: []*models.Table{},
|
||||||
|
target: []*models.Table{
|
||||||
|
{Name: "users", Schema: "public"},
|
||||||
|
},
|
||||||
|
want: func(d *TableDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "users"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified table",
|
||||||
|
source: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "integer"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
target: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "bigint"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: func(d *TableDiff) bool {
|
||||||
|
return len(d.Modified) == 1 && d.Modified[0].Name == "users"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareTables(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareTables() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareSchemas(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
source []*models.Schema
|
||||||
|
target []*models.Schema
|
||||||
|
want func(*SchemaDiff) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical schemas",
|
||||||
|
source: []*models.Schema{},
|
||||||
|
target: []*models.Schema{},
|
||||||
|
want: func(d *SchemaDiff) bool {
|
||||||
|
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing schema",
|
||||||
|
source: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
target: []*models.Schema{},
|
||||||
|
want: func(d *SchemaDiff) bool {
|
||||||
|
return len(d.Missing) == 1 && d.Missing[0].Name == "public"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra schema",
|
||||||
|
source: []*models.Schema{},
|
||||||
|
target: []*models.Schema{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
want: func(d *SchemaDiff) bool {
|
||||||
|
return len(d.Extra) == 1 && d.Extra[0].Name == "public"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := compareSchemas(tt.source, tt.target)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("compareSchemas() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsEmpty(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
v interface{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"empty ColumnDiff", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, true},
|
||||||
|
{"ColumnDiff with missing", &ColumnDiff{Missing: []*models.Column{{Name: "id"}}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, false},
|
||||||
|
{"ColumnDiff with extra", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{{Name: "id"}}, Modified: []*ColumnChange{}}, false},
|
||||||
|
{"empty IndexDiff", &IndexDiff{Missing: []*models.Index{}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, true},
|
||||||
|
{"IndexDiff with missing", &IndexDiff{Missing: []*models.Index{{Name: "idx"}}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, false},
|
||||||
|
{"empty TableDiff", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{}, Modified: []*TableChange{}}, true},
|
||||||
|
{"TableDiff with extra", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{{Name: "users"}}, Modified: []*TableChange{}}, false},
|
||||||
|
{"empty ConstraintDiff", &ConstraintDiff{Missing: []*models.Constraint{}, Extra: []*models.Constraint{}, Modified: []*ConstraintChange{}}, true},
|
||||||
|
{"empty RelationshipDiff", &RelationshipDiff{Missing: []*models.Relationship{}, Extra: []*models.Relationship{}, Modified: []*RelationshipChange{}}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isEmpty(tt.v)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isEmpty() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeSummary(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result *DiffResult
|
||||||
|
want func(*Summary) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty diff",
|
||||||
|
result: &DiffResult{
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: func(s *Summary) bool {
|
||||||
|
return s.Schemas.Missing == 0 && s.Schemas.Extra == 0 && s.Schemas.Modified == 0
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "schemas with differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "schema1"}},
|
||||||
|
Extra: []*models.Schema{{Name: "schema2"}, {Name: "schema3"}},
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: func(s *Summary) bool {
|
||||||
|
return s.Schemas.Missing == 1 && s.Schemas.Extra == 2 && s.Schemas.Modified == 1
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ComputeSummary(tt.result)
|
||||||
|
if !tt.want(got) {
|
||||||
|
t.Errorf("ComputeSummary() result doesn't match expectations")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
440
pkg/diff/formatters_test.go
Normal file
440
pkg/diff/formatters_test.go
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
package diff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatDiff(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source_db",
|
||||||
|
Target: "target_db",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
format OutputFormat
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"summary format", FormatSummary, false},
|
||||||
|
{"json format", FormatJSON, false},
|
||||||
|
{"html format", FormatHTML, false},
|
||||||
|
{"invalid format", OutputFormat("invalid"), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := FormatDiff(result, tt.format, &buf)
|
||||||
|
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("FormatDiff() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr && buf.Len() == 0 {
|
||||||
|
t.Error("FormatDiff() produced empty output")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummary(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result *DiffResult
|
||||||
|
wantStr []string // strings that should appear in output
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{"source", "target", "No differences found"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with schema differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "schema1"}},
|
||||||
|
Extra: []*models.Schema{{Name: "schema2"}},
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{Name: "public"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{"Schemas:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with table differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Missing: []*models.Table{{Name: "users"}},
|
||||||
|
Extra: []*models.Table{{Name: "posts"}},
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{Name: "comments", Schema: "public"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{"Tables:", "Missing: 1", "Extra: 1", "Modified: 1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatSummary(tt.result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatSummary() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
for _, want := range tt.wantStr {
|
||||||
|
if !strings.Contains(output, want) {
|
||||||
|
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatJSON(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "schema1"}},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatJSON(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatJSON() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if output is valid JSON
|
||||||
|
var decoded DiffResult
|
||||||
|
if err := json.Unmarshal(buf.Bytes(), &decoded); err != nil {
|
||||||
|
t.Errorf("formatJSON() produced invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check basic structure
|
||||||
|
if decoded.Source != "source" {
|
||||||
|
t.Errorf("formatJSON() source = %v, want %v", decoded.Source, "source")
|
||||||
|
}
|
||||||
|
if decoded.Target != "target" {
|
||||||
|
t.Errorf("formatJSON() target = %v, want %v", decoded.Target, "target")
|
||||||
|
}
|
||||||
|
if len(decoded.Schemas.Missing) != 1 {
|
||||||
|
t.Errorf("formatJSON() missing schemas = %v, want 1", len(decoded.Schemas.Missing))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatHTML(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
result *DiffResult
|
||||||
|
wantStr []string // HTML elements/content that should appear
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic HTML structure",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{},
|
||||||
|
Extra: []*models.Schema{},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"<title>Database Diff Report</title>",
|
||||||
|
"source",
|
||||||
|
"target",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with schema differences",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "missing_schema"}},
|
||||||
|
Extra: []*models.Schema{{Name: "extra_schema"}},
|
||||||
|
Modified: []*SchemaChange{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{
|
||||||
|
"<!DOCTYPE html>",
|
||||||
|
"missing_schema",
|
||||||
|
"extra_schema",
|
||||||
|
"MISSING",
|
||||||
|
"EXTRA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with table modifications",
|
||||||
|
result: &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: &ColumnDiff{
|
||||||
|
Missing: []*models.Column{{Name: "email", Type: "text"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStr: []string{
|
||||||
|
"public",
|
||||||
|
"users",
|
||||||
|
"email",
|
||||||
|
"text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatHTML(tt.result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatHTML() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
for _, want := range tt.wantStr {
|
||||||
|
if !strings.Contains(output, want) {
|
||||||
|
t.Errorf("formatHTML() output doesn't contain %q", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryWithColumns(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Columns: &ColumnDiff{
|
||||||
|
Missing: []*models.Column{{Name: "email"}},
|
||||||
|
Extra: []*models.Column{{Name: "phone"}, {Name: "address"}},
|
||||||
|
Modified: []*ColumnChange{
|
||||||
|
{Name: "name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatSummary(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatSummary() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
wantStrings := []string{
|
||||||
|
"Columns:",
|
||||||
|
"Missing: 1",
|
||||||
|
"Extra: 2",
|
||||||
|
"Modified: 1",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, want := range wantStrings {
|
||||||
|
if !strings.Contains(output, want) {
|
||||||
|
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryWithIndexes(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Indexes: &IndexDiff{
|
||||||
|
Missing: []*models.Index{{Name: "idx_email"}},
|
||||||
|
Extra: []*models.Index{{Name: "idx_phone"}},
|
||||||
|
Modified: []*IndexChange{{Name: "idx_name"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatSummary(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatSummary() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "Indexes:") {
|
||||||
|
t.Error("formatSummary() output doesn't contain Indexes section")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Missing: 1") {
|
||||||
|
t.Error("formatSummary() output doesn't contain correct missing count")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryWithConstraints(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Modified: []*SchemaChange{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: &TableDiff{
|
||||||
|
Modified: []*TableChange{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Schema: "public",
|
||||||
|
Constraints: &ConstraintDiff{
|
||||||
|
Missing: []*models.Constraint{{Name: "pk_users", Type: "PRIMARY KEY"}},
|
||||||
|
Extra: []*models.Constraint{{Name: "fk_users_roles", Type: "FOREIGN KEY"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatSummary(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatSummary() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "Constraints:") {
|
||||||
|
t.Error("formatSummary() output doesn't contain Constraints section")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatJSONIndentation(t *testing.T) {
|
||||||
|
result := &DiffResult{
|
||||||
|
Source: "source",
|
||||||
|
Target: "target",
|
||||||
|
Schemas: &SchemaDiff{
|
||||||
|
Missing: []*models.Schema{{Name: "test"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := formatJSON(result, &buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("formatJSON() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that JSON is indented (has newlines and spaces)
|
||||||
|
output := buf.String()
|
||||||
|
if !strings.Contains(output, "\n") {
|
||||||
|
t.Error("formatJSON() should produce indented JSON with newlines")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, " ") {
|
||||||
|
t.Error("formatJSON() should produce indented JSON with spaces")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutputFormatConstants(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
format OutputFormat
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"summary constant", FormatSummary, "summary"},
|
||||||
|
{"json constant", FormatJSON, "json"},
|
||||||
|
{"html constant", FormatHTML, "html"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if string(tt.format) != tt.want {
|
||||||
|
t.Errorf("OutputFormat %v = %v, want %v", tt.name, tt.format, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
238
pkg/inspector/inspector_test.go
Normal file
238
pkg/inspector/inspector_test.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package inspector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewInspector(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
|
||||||
|
if inspector == nil {
|
||||||
|
t.Fatal("NewInspector() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if inspector.db != db {
|
||||||
|
t.Error("NewInspector() database not set correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if inspector.config != config {
|
||||||
|
t.Error("NewInspector() config not set correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInspect(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
report, err := inspector.Inspect()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Inspect() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report == nil {
|
||||||
|
t.Fatal("Inspect() returned nil report")
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Database != db.Name {
|
||||||
|
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Summary.TotalRules != len(config.Rules) {
|
||||||
|
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(report.Violations) == 0 {
|
||||||
|
t.Error("Inspect() returned no violations, expected some results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInspectWithDisabledRules(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
// Disable all rules
|
||||||
|
for name := range config.Rules {
|
||||||
|
rule := config.Rules[name]
|
||||||
|
rule.Enabled = "off"
|
||||||
|
config.Rules[name] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
report, err := inspector.Inspect()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Summary.RulesChecked != 0 {
|
||||||
|
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Summary.RulesSkipped != len(config.Rules) {
|
||||||
|
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInspectWithEnforcedRules(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
// Enable only one rule and enforce it
|
||||||
|
for name := range config.Rules {
|
||||||
|
rule := config.Rules[name]
|
||||||
|
rule.Enabled = "off"
|
||||||
|
config.Rules[name] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryKeyRule := config.Rules["primary_key_naming"]
|
||||||
|
primaryKeyRule.Enabled = "enforce"
|
||||||
|
primaryKeyRule.Pattern = "^id$"
|
||||||
|
config.Rules["primary_key_naming"] = primaryKeyRule
|
||||||
|
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
report, err := inspector.Inspect()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Inspect() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if report.Summary.RulesChecked != 1 {
|
||||||
|
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All results should be at error level for enforced rules
|
||||||
|
for _, violation := range report.Violations {
|
||||||
|
if violation.Level != "error" {
|
||||||
|
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSummary(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
inspector := NewInspector(db, config)
|
||||||
|
|
||||||
|
results := []ValidationResult{
|
||||||
|
{RuleName: "rule1", Passed: true, Level: "error"},
|
||||||
|
{RuleName: "rule2", Passed: false, Level: "error"},
|
||||||
|
{RuleName: "rule3", Passed: false, Level: "warning"},
|
||||||
|
{RuleName: "rule4", Passed: true, Level: "warning"},
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := inspector.generateSummary(results)
|
||||||
|
|
||||||
|
if summary.PassedCount != 2 {
|
||||||
|
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if summary.ErrorCount != 1 {
|
||||||
|
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if summary.WarningCount != 1 {
|
||||||
|
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasErrors(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
report *InspectorReport
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with errors",
|
||||||
|
report: &InspectorReport{
|
||||||
|
Summary: ReportSummary{
|
||||||
|
ErrorCount: 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without errors",
|
||||||
|
report: &InspectorReport{
|
||||||
|
Summary: ReportSummary{
|
||||||
|
ErrorCount: 0,
|
||||||
|
WarningCount: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.report.HasErrors(); got != tt.want {
|
||||||
|
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValidator(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
functionName string
|
||||||
|
wantExists bool
|
||||||
|
}{
|
||||||
|
{"primary_key_naming", "primary_key_naming", true},
|
||||||
|
{"primary_key_datatype", "primary_key_datatype", true},
|
||||||
|
{"foreign_key_column_naming", "foreign_key_column_naming", true},
|
||||||
|
{"table_regexpr", "table_regexpr", true},
|
||||||
|
{"column_regexpr", "column_regexpr", true},
|
||||||
|
{"reserved_words", "reserved_words", true},
|
||||||
|
{"have_primary_key", "have_primary_key", true},
|
||||||
|
{"orphaned_foreign_key", "orphaned_foreign_key", true},
|
||||||
|
{"circular_dependency", "circular_dependency", true},
|
||||||
|
{"unknown_function", "unknown_function", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, exists := getValidator(tt.functionName)
|
||||||
|
if exists != tt.wantExists {
|
||||||
|
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateResult(t *testing.T) {
|
||||||
|
result := createResult(
|
||||||
|
"test_rule",
|
||||||
|
true,
|
||||||
|
"Test message",
|
||||||
|
"schema.table.column",
|
||||||
|
map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": 42,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.RuleName != "test_rule" {
|
||||||
|
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Passed {
|
||||||
|
t.Error("createResult() Passed = false, want true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Message != "Test message" {
|
||||||
|
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Location != "schema.table.column" {
|
||||||
|
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Context) != 2 {
|
||||||
|
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
|
||||||
|
}
|
||||||
|
}
|
||||||
366
pkg/inspector/report_test.go
Normal file
366
pkg/inspector/report_test.go
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
package inspector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTestReport() *InspectorReport {
|
||||||
|
return &InspectorReport{
|
||||||
|
Summary: ReportSummary{
|
||||||
|
TotalRules: 10,
|
||||||
|
RulesChecked: 8,
|
||||||
|
RulesSkipped: 2,
|
||||||
|
ErrorCount: 3,
|
||||||
|
WarningCount: 5,
|
||||||
|
PassedCount: 12,
|
||||||
|
},
|
||||||
|
Violations: []ValidationResult{
|
||||||
|
{
|
||||||
|
RuleName: "primary_key_naming",
|
||||||
|
Level: "error",
|
||||||
|
Message: "Primary key should start with 'id_'",
|
||||||
|
Location: "public.users.user_id",
|
||||||
|
Passed: false,
|
||||||
|
Context: map[string]interface{}{
|
||||||
|
"schema": "public",
|
||||||
|
"table": "users",
|
||||||
|
"column": "user_id",
|
||||||
|
"pattern": "^id_",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
RuleName: "table_name_length",
|
||||||
|
Level: "warning",
|
||||||
|
Message: "Table name too long",
|
||||||
|
Location: "public.very_long_table_name_that_exceeds_limits",
|
||||||
|
Passed: false,
|
||||||
|
Context: map[string]interface{}{
|
||||||
|
"schema": "public",
|
||||||
|
"table": "very_long_table_name_that_exceeds_limits",
|
||||||
|
"length": 44,
|
||||||
|
"max_length": 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
GeneratedAt: time.Now(),
|
||||||
|
Database: "testdb",
|
||||||
|
SourceFormat: "postgresql",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewMarkdownFormatter(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
formatter := NewMarkdownFormatter(&buf)
|
||||||
|
|
||||||
|
if formatter == nil {
|
||||||
|
t.Fatal("NewMarkdownFormatter() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffer is not a terminal, so colors should be disabled
|
||||||
|
if formatter.UseColors {
|
||||||
|
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewJSONFormatter(t *testing.T) {
|
||||||
|
formatter := NewJSONFormatter()
|
||||||
|
|
||||||
|
if formatter == nil {
|
||||||
|
t.Fatal("NewJSONFormatter() returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_Format(t *testing.T) {
|
||||||
|
report := createTestReport()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
formatter := NewMarkdownFormatter(&buf)
|
||||||
|
|
||||||
|
output, err := formatter.Format(report)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that output contains expected sections
|
||||||
|
if !strings.Contains(output, "# RelSpec Inspector Report") {
|
||||||
|
t.Error("Markdown output missing header")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Database:") {
|
||||||
|
t.Error("Markdown output missing database field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "testdb") {
|
||||||
|
t.Error("Markdown output missing database name")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Summary") {
|
||||||
|
t.Error("Markdown output missing summary section")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Rules Checked: 8") {
|
||||||
|
t.Error("Markdown output missing rules checked count")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Errors: 3") {
|
||||||
|
t.Error("Markdown output missing error count")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Warnings: 5") {
|
||||||
|
t.Error("Markdown output missing warning count")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Violations") {
|
||||||
|
t.Error("Markdown output missing violations section")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "primary_key_naming") {
|
||||||
|
t.Error("Markdown output missing rule name")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "public.users.user_id") {
|
||||||
|
t.Error("Markdown output missing location")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
|
||||||
|
report := &InspectorReport{
|
||||||
|
Summary: ReportSummary{
|
||||||
|
TotalRules: 10,
|
||||||
|
RulesChecked: 10,
|
||||||
|
RulesSkipped: 0,
|
||||||
|
ErrorCount: 0,
|
||||||
|
WarningCount: 0,
|
||||||
|
PassedCount: 50,
|
||||||
|
},
|
||||||
|
Violations: []ValidationResult{},
|
||||||
|
GeneratedAt: time.Now(),
|
||||||
|
Database: "testdb",
|
||||||
|
SourceFormat: "postgresql",
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
formatter := NewMarkdownFormatter(&buf)
|
||||||
|
|
||||||
|
output, err := formatter.Format(report)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(output, "No violations found") {
|
||||||
|
t.Error("Markdown output should indicate no violations")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONFormatter_Format(t *testing.T) {
|
||||||
|
report := createTestReport()
|
||||||
|
formatter := NewJSONFormatter()
|
||||||
|
|
||||||
|
output, err := formatter.Format(report)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's valid JSON
|
||||||
|
var decoded InspectorReport
|
||||||
|
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
|
||||||
|
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check key fields
|
||||||
|
if decoded.Database != "testdb" {
|
||||||
|
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded.Summary.ErrorCount != 3 {
|
||||||
|
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(decoded.Violations) != 2 {
|
||||||
|
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
formatter := NewMarkdownFormatter(&buf)
|
||||||
|
|
||||||
|
header := formatter.formatHeader("Test Header")
|
||||||
|
|
||||||
|
if !strings.Contains(header, "# Test Header") {
|
||||||
|
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatBold(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
useColors bool
|
||||||
|
text string
|
||||||
|
wantContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "without colors",
|
||||||
|
useColors: false,
|
||||||
|
text: "Bold Text",
|
||||||
|
wantContains: "**Bold Text**",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with colors",
|
||||||
|
useColors: true,
|
||||||
|
text: "Bold Text",
|
||||||
|
wantContains: "Bold Text",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||||
|
result := formatter.formatBold(tt.text)
|
||||||
|
|
||||||
|
if !strings.Contains(result, tt.wantContains) {
|
||||||
|
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_Colorize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
useColors bool
|
||||||
|
text string
|
||||||
|
color string
|
||||||
|
wantColor bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "without colors",
|
||||||
|
useColors: false,
|
||||||
|
text: "Test",
|
||||||
|
color: colorRed,
|
||||||
|
wantColor: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with colors",
|
||||||
|
useColors: true,
|
||||||
|
text: "Test",
|
||||||
|
color: colorRed,
|
||||||
|
wantColor: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||||
|
result := formatter.colorize(tt.text, tt.color)
|
||||||
|
|
||||||
|
hasColor := strings.Contains(result, tt.color)
|
||||||
|
if hasColor != tt.wantColor {
|
||||||
|
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, tt.text) {
|
||||||
|
t.Errorf("colorize() doesn't contain original text %q", tt.text)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatContext(t *testing.T) {
|
||||||
|
formatter := &MarkdownFormatter{UseColors: false}
|
||||||
|
|
||||||
|
context := map[string]interface{}{
|
||||||
|
"schema": "public",
|
||||||
|
"table": "users",
|
||||||
|
"column": "id",
|
||||||
|
"pattern": "^id_",
|
||||||
|
"max_length": 64,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := formatter.formatContext(context)
|
||||||
|
|
||||||
|
// Should not include schema, table, column (they're in location)
|
||||||
|
if strings.Contains(result, "schema") {
|
||||||
|
t.Error("formatContext() should skip schema field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(result, "table=") {
|
||||||
|
t.Error("formatContext() should skip table field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(result, "column=") {
|
||||||
|
t.Error("formatContext() should skip column field")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should include other fields
|
||||||
|
if !strings.Contains(result, "pattern") {
|
||||||
|
t.Error("formatContext() should include pattern field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "max_length") {
|
||||||
|
t.Error("formatContext() should include max_length field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
|
||||||
|
formatter := &MarkdownFormatter{UseColors: false}
|
||||||
|
|
||||||
|
violation := ValidationResult{
|
||||||
|
RuleName: "test_rule",
|
||||||
|
Level: "error",
|
||||||
|
Message: "Test violation message",
|
||||||
|
Location: "public.users.id",
|
||||||
|
Passed: false,
|
||||||
|
Context: map[string]interface{}{
|
||||||
|
"pattern": "^id_",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := formatter.formatViolation(violation, colorRed)
|
||||||
|
|
||||||
|
if !strings.Contains(result, "test_rule") {
|
||||||
|
t.Error("formatViolation() should include rule name")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "Test violation message") {
|
||||||
|
t.Error("formatViolation() should include message")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "public.users.id") {
|
||||||
|
t.Error("formatViolation() should include location")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "Location:") {
|
||||||
|
t.Error("formatViolation() should include Location label")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result, "Message:") {
|
||||||
|
t.Error("formatViolation() should include Message label")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReportFormatConstants(t *testing.T) {
|
||||||
|
// Test that color constants are defined
|
||||||
|
if colorReset == "" {
|
||||||
|
t.Error("colorReset is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if colorRed == "" {
|
||||||
|
t.Error("colorRed is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if colorYellow == "" {
|
||||||
|
t.Error("colorYellow is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if colorGreen == "" {
|
||||||
|
t.Error("colorGreen is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
if colorBold == "" {
|
||||||
|
t.Error("colorBold is not defined")
|
||||||
|
}
|
||||||
|
}
|
||||||
249
pkg/inspector/rules_test.go
Normal file
249
pkg/inspector/rules_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package inspector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetDefaultConfig(t *testing.T) {
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
t.Fatal("GetDefaultConfig() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Version != "1.0" {
|
||||||
|
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Rules) == 0 {
|
||||||
|
t.Error("GetDefaultConfig() returned no rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all expected rules are present
|
||||||
|
expectedRules := []string{
|
||||||
|
"primary_key_naming",
|
||||||
|
"primary_key_datatype",
|
||||||
|
"primary_key_auto_increment",
|
||||||
|
"foreign_key_column_naming",
|
||||||
|
"foreign_key_constraint_naming",
|
||||||
|
"foreign_key_index",
|
||||||
|
"table_naming_case",
|
||||||
|
"column_naming_case",
|
||||||
|
"table_name_length",
|
||||||
|
"column_name_length",
|
||||||
|
"reserved_keywords",
|
||||||
|
"missing_primary_key",
|
||||||
|
"orphaned_foreign_key",
|
||||||
|
"circular_dependency",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ruleName := range expectedRules {
|
||||||
|
if _, exists := config.Rules[ruleName]; !exists {
|
||||||
|
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfig_NonExistentFile(t *testing.T) {
|
||||||
|
// Try to load a non-existent file
|
||||||
|
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return default config
|
||||||
|
if config == nil {
|
||||||
|
t.Fatal("LoadConfig() returned nil config for non-existent file")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Rules) == 0 {
|
||||||
|
t.Error("LoadConfig() returned config with no rules")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfig_ValidFile(t *testing.T) {
|
||||||
|
// Create a temporary config file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "test-config.yaml")
|
||||||
|
|
||||||
|
configContent := `version: "1.0"
|
||||||
|
rules:
|
||||||
|
primary_key_naming:
|
||||||
|
enabled: "enforce"
|
||||||
|
function: "primary_key_naming"
|
||||||
|
pattern: "^pk_"
|
||||||
|
message: "Primary keys must start with pk_"
|
||||||
|
table_name_length:
|
||||||
|
enabled: "warn"
|
||||||
|
function: "table_name_length"
|
||||||
|
max_length: 50
|
||||||
|
message: "Table name too long"
|
||||||
|
`
|
||||||
|
|
||||||
|
err := os.WriteFile(configPath, []byte(configContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create test config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfig() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Version != "1.0" {
|
||||||
|
t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Rules) != 2 {
|
||||||
|
t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check primary_key_naming rule
|
||||||
|
pkRule, exists := config.Rules["primary_key_naming"]
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("LoadConfig() missing primary_key_naming rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkRule.Enabled != "enforce" {
|
||||||
|
t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkRule.Pattern != "^pk_" {
|
||||||
|
t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check table_name_length rule
|
||||||
|
lengthRule, exists := config.Rules["table_name_length"]
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("LoadConfig() missing table_name_length rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
if lengthRule.MaxLength != 50 {
|
||||||
|
t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfig_InvalidYAML(t *testing.T) {
|
||||||
|
// Create a temporary invalid config file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
|
||||||
|
|
||||||
|
invalidContent := `invalid: yaml: content: {[}]`
|
||||||
|
|
||||||
|
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create test config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = LoadConfig(configPath)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("LoadConfig() with invalid YAML did not return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuleIsEnabled(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enforce is enabled",
|
||||||
|
rule: Rule{Enabled: "enforce"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warn is enabled",
|
||||||
|
rule: Rule{Enabled: "warn"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "off is not enabled",
|
||||||
|
rule: Rule{Enabled: "off"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty is not enabled",
|
||||||
|
rule: Rule{Enabled: ""},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.rule.IsEnabled(); got != tt.want {
|
||||||
|
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRuleIsEnforced(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enforce is enforced",
|
||||||
|
rule: Rule{Enabled: "enforce"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warn is not enforced",
|
||||||
|
rule: Rule{Enabled: "warn"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "off is not enforced",
|
||||||
|
rule: Rule{Enabled: "off"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.rule.IsEnforced(); got != tt.want {
|
||||||
|
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultConfigRuleSettings(t *testing.T) {
|
||||||
|
config := GetDefaultConfig()
|
||||||
|
|
||||||
|
// Test specific rule settings
|
||||||
|
pkNamingRule := config.Rules["primary_key_naming"]
|
||||||
|
if pkNamingRule.Function != "primary_key_naming" {
|
||||||
|
t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pkNamingRule.Pattern != "^id_" {
|
||||||
|
t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test datatype rule
|
||||||
|
pkDatatypeRule := config.Rules["primary_key_datatype"]
|
||||||
|
if len(pkDatatypeRule.AllowedTypes) == 0 {
|
||||||
|
t.Error("primary_key_datatype has no allowed types")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test length rule
|
||||||
|
tableLengthRule := config.Rules["table_name_length"]
|
||||||
|
if tableLengthRule.MaxLength != 64 {
|
||||||
|
t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test reserved keywords rule
|
||||||
|
reservedRule := config.Rules["reserved_keywords"]
|
||||||
|
if !reservedRule.CheckTables {
|
||||||
|
t.Error("reserved_keywords.CheckTables should be true")
|
||||||
|
}
|
||||||
|
if !reservedRule.CheckColumns {
|
||||||
|
t.Error("reserved_keywords.CheckColumns should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
837
pkg/inspector/validators_test.go
Normal file
837
pkg/inspector/validators_test.go
Normal file
@@ -0,0 +1,837 @@
|
|||||||
|
package inspector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper function to create test database
|
||||||
|
func createTestDatabase() *models.Database {
|
||||||
|
return &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigserial",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
AutoIncrement: true,
|
||||||
|
},
|
||||||
|
"username": {
|
||||||
|
Name: "username",
|
||||||
|
Type: "varchar(50)",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
},
|
||||||
|
"rid_organization": {
|
||||||
|
Name: "rid_organization",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"fk_users_organization": {
|
||||||
|
Name: "fk_users_organization",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_organization"},
|
||||||
|
ReferencedTable: "organizations",
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
ReferencedColumns: []string{"id"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Indexes: map[string]*models.Index{
|
||||||
|
"idx_rid_organization": {
|
||||||
|
Name: "idx_rid_organization",
|
||||||
|
Columns: []string{"rid_organization"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "organizations",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigserial",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
AutoIncrement: true,
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
Name: "name",
|
||||||
|
Type: "varchar(100)",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePrimaryKeyNaming(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "matching pattern id",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^id$",
|
||||||
|
Message: "Primary key should be 'id'",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching pattern id_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^id_",
|
||||||
|
Message: "Primary key should start with 'id_'",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePrimaryKeyDatatype(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "allowed type bigserial",
|
||||||
|
rule: Rule{
|
||||||
|
AllowedTypes: []string{"bigserial", "bigint", "int"},
|
||||||
|
Message: "Primary key should use integer types",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disallowed type",
|
||||||
|
rule: Rule{
|
||||||
|
AllowedTypes: []string{"uuid"},
|
||||||
|
Message: "Primary key should use UUID",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "require auto increment",
|
||||||
|
rule: Rule{
|
||||||
|
RequireAutoIncrement: true,
|
||||||
|
Message: "Primary key should have auto-increment",
|
||||||
|
},
|
||||||
|
wantLen: 0, // No violations - all PKs have auto-increment
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disallow auto increment",
|
||||||
|
rule: Rule{
|
||||||
|
RequireAutoIncrement: false,
|
||||||
|
Message: "Primary key should not have auto-increment",
|
||||||
|
},
|
||||||
|
wantLen: 2, // 2 violations - both PKs have auto-increment
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateForeignKeyColumnNaming(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "matching pattern rid_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^rid_",
|
||||||
|
Message: "Foreign key columns should start with 'rid_'",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching pattern fk_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^fk_",
|
||||||
|
Message: "Foreign key columns should start with 'fk_'",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "matching pattern fk_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^fk_",
|
||||||
|
Message: "Foreign key constraints should start with 'fk_'",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching pattern FK_",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^FK_",
|
||||||
|
Message: "Foreign key constraints should start with 'FK_'",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateForeignKeyIndex(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "require index with index present",
|
||||||
|
rule: Rule{
|
||||||
|
RequireIndex: true,
|
||||||
|
Message: "Foreign key columns should have indexes",
|
||||||
|
},
|
||||||
|
wantLen: 1,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no requirement",
|
||||||
|
rule: Rule{
|
||||||
|
RequireIndex: false,
|
||||||
|
Message: "Foreign key index check disabled",
|
||||||
|
},
|
||||||
|
wantLen: 0,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateForeignKeyIndex(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTableNamingCase(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "lowercase snake_case pattern",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^[a-z][a-z0-9_]*$",
|
||||||
|
Case: "lowercase",
|
||||||
|
Message: "Table names should be lowercase snake_case",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase pattern",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^[A-Z][A-Z0-9_]*$",
|
||||||
|
Case: "uppercase",
|
||||||
|
Message: "Table names should be uppercase",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateTableNamingCase(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||||
|
t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumnNamingCase(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "lowercase snake_case pattern",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^[a-z][a-z0-9_]*$",
|
||||||
|
Case: "lowercase",
|
||||||
|
Message: "Column names should be lowercase snake_case",
|
||||||
|
},
|
||||||
|
wantLen: 5, // 5 total columns across both tables
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "camelCase pattern",
|
||||||
|
rule: Rule{
|
||||||
|
Pattern: "^[a-z][a-zA-Z0-9]*$",
|
||||||
|
Case: "camelCase",
|
||||||
|
Message: "Column names should be camelCase",
|
||||||
|
},
|
||||||
|
wantLen: 5,
|
||||||
|
wantPass: false, // rid_organization has underscore
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateColumnNamingCase(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTableNameLength(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "max length 64",
|
||||||
|
rule: Rule{
|
||||||
|
MaxLength: 64,
|
||||||
|
Message: "Table name too long",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max length 5",
|
||||||
|
rule: Rule{
|
||||||
|
MaxLength: 5,
|
||||||
|
Message: "Table name too long",
|
||||||
|
},
|
||||||
|
wantLen: 2,
|
||||||
|
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateTableNameLength(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumnNameLength(t *testing.T) {
|
||||||
|
db := createTestDatabase()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
wantPass bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "max length 64",
|
||||||
|
rule: Rule{
|
||||||
|
MaxLength: 64,
|
||||||
|
Message: "Column name too long",
|
||||||
|
},
|
||||||
|
wantLen: 5,
|
||||||
|
wantPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max length 5",
|
||||||
|
rule: Rule{
|
||||||
|
MaxLength: 5,
|
||||||
|
Message: "Column name too long",
|
||||||
|
},
|
||||||
|
wantLen: 5,
|
||||||
|
wantPass: false, // Some columns exceed 5 chars
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateColumnNameLength(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateReservedKeywords(t *testing.T) {
|
||||||
|
// Create a database with reserved keywords
|
||||||
|
db := &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "user", // "user" is a reserved keyword
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigint",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
"select": { // "select" is a reserved keyword
|
||||||
|
Name: "select",
|
||||||
|
Type: "varchar(50)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rule Rule
|
||||||
|
wantLen int
|
||||||
|
checkPasses bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "check tables only",
|
||||||
|
rule: Rule{
|
||||||
|
CheckTables: true,
|
||||||
|
CheckColumns: false,
|
||||||
|
Message: "Reserved keyword used",
|
||||||
|
},
|
||||||
|
wantLen: 1, // "user" table
|
||||||
|
checkPasses: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check columns only",
|
||||||
|
rule: Rule{
|
||||||
|
CheckTables: false,
|
||||||
|
CheckColumns: true,
|
||||||
|
Message: "Reserved keyword used",
|
||||||
|
},
|
||||||
|
wantLen: 2, // "id", "select" columns (id passes, select fails)
|
||||||
|
checkPasses: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check both",
|
||||||
|
rule: Rule{
|
||||||
|
CheckTables: true,
|
||||||
|
CheckColumns: true,
|
||||||
|
Message: "Reserved keyword used",
|
||||||
|
},
|
||||||
|
wantLen: 3, // "user" table + "id", "select" columns
|
||||||
|
checkPasses: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results := validateReservedKeywords(db, tt.rule, "test_rule")
|
||||||
|
if len(results) != tt.wantLen {
|
||||||
|
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMissingPrimaryKey(t *testing.T) {
|
||||||
|
// Create database with and without primary keys
|
||||||
|
db := &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "with_pk",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigint",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "without_pk",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"name": {
|
||||||
|
Name: "name",
|
||||||
|
Type: "varchar(50)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
Message: "Table missing primary key",
|
||||||
|
}
|
||||||
|
|
||||||
|
results := validateMissingPrimaryKey(db, rule, "test_rule")
|
||||||
|
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
// First result should pass (with_pk has PK)
|
||||||
|
if results[0].Passed != true {
|
||||||
|
t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second result should fail (without_pk missing PK)
|
||||||
|
if results[1].Passed != false {
|
||||||
|
t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOrphanedForeignKey(t *testing.T) {
|
||||||
|
// Create database with orphaned FK
|
||||||
|
db := &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "users",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {
|
||||||
|
Name: "id",
|
||||||
|
Type: "bigint",
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"fk_nonexistent": {
|
||||||
|
Name: "fk_nonexistent",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_organization"},
|
||||||
|
ReferencedTable: "nonexistent_table",
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
Message: "Foreign key references non-existent table",
|
||||||
|
}
|
||||||
|
|
||||||
|
results := validateOrphanedForeignKey(db, rule, "test_rule")
|
||||||
|
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[0].Passed != false {
|
||||||
|
t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateCircularDependency(t *testing.T) {
|
||||||
|
// Create database with circular dependency
|
||||||
|
db := &models.Database{
|
||||||
|
Name: "testdb",
|
||||||
|
Schemas: []*models.Schema{
|
||||||
|
{
|
||||||
|
Name: "public",
|
||||||
|
Tables: []*models.Table{
|
||||||
|
{
|
||||||
|
Name: "table_a",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"fk_to_b": {
|
||||||
|
Name: "fk_to_b",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
ReferencedTable: "table_b",
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "table_b",
|
||||||
|
Columns: map[string]*models.Column{
|
||||||
|
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||||
|
},
|
||||||
|
Constraints: map[string]*models.Constraint{
|
||||||
|
"fk_to_a": {
|
||||||
|
Name: "fk_to_a",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
ReferencedTable: "table_a",
|
||||||
|
ReferencedSchema: "public",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
Message: "Circular dependency detected",
|
||||||
|
}
|
||||||
|
|
||||||
|
results := validateCircularDependency(db, rule, "test_rule")
|
||||||
|
|
||||||
|
// Should detect circular dependency in both tables
|
||||||
|
if len(results) == 0 {
|
||||||
|
t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Passed {
|
||||||
|
t.Error("validateCircularDependency() passed=true, want false for circular dependency")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeDataType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"varchar(50)", "varchar"},
|
||||||
|
{"decimal(10,2)", "decimal"},
|
||||||
|
{"int", "int"},
|
||||||
|
{"BIGINT", "bigint"},
|
||||||
|
{"VARCHAR(255)", "varchar"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
result := normalizeDataType(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
slice []string
|
||||||
|
value string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
|
||||||
|
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
|
||||||
|
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
|
||||||
|
{"empty slice", []string{}, "foo", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := contains(tt.slice, tt.value)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasCycle(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
graph map[string][]string
|
||||||
|
node string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple cycle",
|
||||||
|
graph: map[string][]string{
|
||||||
|
"A": {"B"},
|
||||||
|
"B": {"C"},
|
||||||
|
"C": {"A"},
|
||||||
|
},
|
||||||
|
node: "A",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no cycle",
|
||||||
|
graph: map[string][]string{
|
||||||
|
"A": {"B"},
|
||||||
|
"B": {"C"},
|
||||||
|
"C": {},
|
||||||
|
},
|
||||||
|
node: "A",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "self cycle",
|
||||||
|
graph: map[string][]string{
|
||||||
|
"A": {"A"},
|
||||||
|
},
|
||||||
|
node: "A",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
visited := make(map[string]bool)
|
||||||
|
recStack := make(map[string]bool)
|
||||||
|
result := hasCycle(tt.node, tt.graph, visited, recStack)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatLocation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
schema string
|
||||||
|
table string
|
||||||
|
column string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"public", "users", "id", "public.users.id"},
|
||||||
|
{"public", "users", "", "public.users"},
|
||||||
|
{"public", "", "", "public"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.expected, func(t *testing.T) {
|
||||||
|
result := formatLocation(tt.schema, tt.table, tt.column)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
|
||||||
|
tt.schema, tt.table, tt.column, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
339
pkg/pgsql/datatypes_test.go
Normal file
339
pkg/pgsql/datatypes_test.go
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
package pgsql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidSQLType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqltype string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// PostgreSQL types
|
||||||
|
{"Valid PGSQL bigint", "bigint", true},
|
||||||
|
{"Valid PGSQL integer", "integer", true},
|
||||||
|
{"Valid PGSQL text", "text", true},
|
||||||
|
{"Valid PGSQL boolean", "boolean", true},
|
||||||
|
{"Valid PGSQL double precision", "double precision", true},
|
||||||
|
{"Valid PGSQL bytea", "bytea", true},
|
||||||
|
{"Valid PGSQL uuid", "uuid", true},
|
||||||
|
{"Valid PGSQL jsonb", "jsonb", true},
|
||||||
|
{"Valid PGSQL json", "json", true},
|
||||||
|
{"Valid PGSQL timestamp", "timestamp", true},
|
||||||
|
{"Valid PGSQL date", "date", true},
|
||||||
|
{"Valid PGSQL time", "time", true},
|
||||||
|
{"Valid PGSQL citext", "citext", true},
|
||||||
|
|
||||||
|
// Standard types
|
||||||
|
{"Valid std double", "double", true},
|
||||||
|
{"Valid std blob", "blob", true},
|
||||||
|
|
||||||
|
// Case insensitive
|
||||||
|
{"Case insensitive BIGINT", "BIGINT", true},
|
||||||
|
{"Case insensitive TeXt", "TeXt", true},
|
||||||
|
{"Case insensitive BoOlEaN", "BoOlEaN", true},
|
||||||
|
|
||||||
|
// Invalid types
|
||||||
|
{"Invalid type", "invalidtype", false},
|
||||||
|
{"Invalid type varchar", "varchar", false},
|
||||||
|
{"Empty string", "", false},
|
||||||
|
{"Random string", "foobar", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ValidSQLType(tt.sqltype)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ValidSQLType(%q) = %v, want %v", tt.sqltype, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
anytype string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Go types to PostgreSQL types
|
||||||
|
{"Go bool to boolean", "bool", "boolean"},
|
||||||
|
{"Go int64 to bigint", "int64", "bigint"},
|
||||||
|
{"Go int to integer", "int", "integer"},
|
||||||
|
{"Go string to text", "string", "text"},
|
||||||
|
{"Go float64 to double precision", "float64", "double precision"},
|
||||||
|
{"Go float32 to real", "float32", "real"},
|
||||||
|
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||||
|
|
||||||
|
// SQL types remain SQL types
|
||||||
|
{"SQL bigint", "bigint", "bigint"},
|
||||||
|
{"SQL integer", "integer", "integer"},
|
||||||
|
{"SQL text", "text", "text"},
|
||||||
|
{"SQL boolean", "boolean", "boolean"},
|
||||||
|
{"SQL uuid", "uuid", "uuid"},
|
||||||
|
{"SQL jsonb", "jsonb", "jsonb"},
|
||||||
|
|
||||||
|
// Case insensitive Go types
|
||||||
|
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||||
|
{"Case insensitive InT64", "InT64", "bigint"},
|
||||||
|
{"Case insensitive STRING", "STRING", "text"},
|
||||||
|
|
||||||
|
// Case insensitive SQL types
|
||||||
|
{"Case insensitive BIGINT", "BIGINT", "bigint"},
|
||||||
|
{"Case insensitive TEXT", "TEXT", "text"},
|
||||||
|
|
||||||
|
// Custom types
|
||||||
|
{"Custom sqluuid", "sqluuid", "uuid"},
|
||||||
|
{"Custom sqljsonb", "sqljsonb", "jsonb"},
|
||||||
|
{"Custom sqlint64", "sqlint64", "bigint"},
|
||||||
|
|
||||||
|
// Unknown types default to text
|
||||||
|
{"Unknown type varchar", "varchar", "text"},
|
||||||
|
{"Unknown type foobar", "foobar", "text"},
|
||||||
|
{"Empty string", "", "text"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetSQLType(tt.anytype)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("GetSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertSQLType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
anytype string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Go types to PostgreSQL types
|
||||||
|
{"Go bool to boolean", "bool", "boolean"},
|
||||||
|
{"Go int64 to bigint", "int64", "bigint"},
|
||||||
|
{"Go int to integer", "int", "integer"},
|
||||||
|
{"Go string to text", "string", "text"},
|
||||||
|
{"Go float64 to double precision", "float64", "double precision"},
|
||||||
|
{"Go float32 to real", "float32", "real"},
|
||||||
|
{"Go []byte to bytea", "[]byte", "bytea"},
|
||||||
|
|
||||||
|
// SQL types remain SQL types
|
||||||
|
{"SQL bigint", "bigint", "bigint"},
|
||||||
|
{"SQL integer", "integer", "integer"},
|
||||||
|
{"SQL text", "text", "text"},
|
||||||
|
{"SQL boolean", "boolean", "boolean"},
|
||||||
|
|
||||||
|
// Case insensitive
|
||||||
|
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||||
|
{"Case insensitive InT64", "InT64", "bigint"},
|
||||||
|
|
||||||
|
// Unknown types remain unchanged (difference from GetSQLType)
|
||||||
|
{"Unknown type varchar", "varchar", "varchar"},
|
||||||
|
{"Unknown type foobar", "foobar", "foobar"},
|
||||||
|
{"Empty string", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ConvertSQLType(tt.anytype)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ConvertSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsGoType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
typeName string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// Go basic types
|
||||||
|
{"Go bool", "bool", true},
|
||||||
|
{"Go int64", "int64", true},
|
||||||
|
{"Go int", "int", true},
|
||||||
|
{"Go int32", "int32", true},
|
||||||
|
{"Go int16", "int16", true},
|
||||||
|
{"Go int8", "int8", true},
|
||||||
|
{"Go uint", "uint", true},
|
||||||
|
{"Go uint64", "uint64", true},
|
||||||
|
{"Go uint32", "uint32", true},
|
||||||
|
{"Go uint16", "uint16", true},
|
||||||
|
{"Go uint8", "uint8", true},
|
||||||
|
{"Go float64", "float64", true},
|
||||||
|
{"Go float32", "float32", true},
|
||||||
|
{"Go string", "string", true},
|
||||||
|
{"Go []byte", "[]byte", true},
|
||||||
|
|
||||||
|
// Go custom types
|
||||||
|
{"Go complex64", "complex64", true},
|
||||||
|
{"Go complex128", "complex128", true},
|
||||||
|
{"Go uintptr", "uintptr", true},
|
||||||
|
{"Go Pointer", "Pointer", true},
|
||||||
|
|
||||||
|
// Custom SQL types
|
||||||
|
{"Custom sqluuid", "sqluuid", true},
|
||||||
|
{"Custom sqljsonb", "sqljsonb", true},
|
||||||
|
{"Custom sqlint64", "sqlint64", true},
|
||||||
|
{"Custom customdate", "customdate", true},
|
||||||
|
{"Custom customtime", "customtime", true},
|
||||||
|
|
||||||
|
// Case insensitive
|
||||||
|
{"Case insensitive BOOL", "BOOL", true},
|
||||||
|
{"Case insensitive InT64", "InT64", true},
|
||||||
|
{"Case insensitive STRING", "STRING", true},
|
||||||
|
|
||||||
|
// SQL types (not Go types)
|
||||||
|
{"SQL bigint", "bigint", false},
|
||||||
|
{"SQL integer", "integer", false},
|
||||||
|
{"SQL text", "text", false},
|
||||||
|
{"SQL boolean", "boolean", false},
|
||||||
|
|
||||||
|
// Invalid types
|
||||||
|
{"Invalid type", "invalidtype", false},
|
||||||
|
{"Empty string", "", false},
|
||||||
|
{"Random string", "foobar", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := IsGoType(tt.typeName)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsGoType(%q) = %v, want %v", tt.typeName, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStdTypeFromGo(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
typeName string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// Go types to standard SQL types
|
||||||
|
{"Go bool to boolean", "bool", "boolean"},
|
||||||
|
{"Go int64 to bigint", "int64", "bigint"},
|
||||||
|
{"Go int to integer", "int", "integer"},
|
||||||
|
{"Go string to text", "string", "text"},
|
||||||
|
{"Go float64 to double", "float64", "double"},
|
||||||
|
{"Go float32 to double", "float32", "double"},
|
||||||
|
{"Go []byte to blob", "[]byte", "blob"},
|
||||||
|
{"Go int32 to integer", "int32", "integer"},
|
||||||
|
{"Go int16 to smallint", "int16", "smallint"},
|
||||||
|
|
||||||
|
// Custom types
|
||||||
|
{"Custom sqluuid to uuid", "sqluuid", "uuid"},
|
||||||
|
{"Custom sqljsonb to jsonb", "sqljsonb", "jsonb"},
|
||||||
|
{"Custom sqlint64 to bigint", "sqlint64", "bigint"},
|
||||||
|
{"Custom customdate to date", "customdate", "date"},
|
||||||
|
|
||||||
|
// Case insensitive
|
||||||
|
{"Case insensitive BOOL", "BOOL", "boolean"},
|
||||||
|
{"Case insensitive InT64", "InT64", "bigint"},
|
||||||
|
{"Case insensitive STRING", "STRING", "text"},
|
||||||
|
|
||||||
|
// Non-Go types remain unchanged
|
||||||
|
{"SQL bigint unchanged", "bigint", "bigint"},
|
||||||
|
{"SQL integer unchanged", "integer", "integer"},
|
||||||
|
{"Invalid type unchanged", "invalidtype", "invalidtype"},
|
||||||
|
{"Empty string unchanged", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetStdTypeFromGo(tt.typeName)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("GetStdTypeFromGo(%q) = %q, want %q", tt.typeName, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoToStdTypesMap(t *testing.T) {
|
||||||
|
// Test that the map contains expected entries
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"bool": "boolean",
|
||||||
|
"int64": "bigint",
|
||||||
|
"int": "integer",
|
||||||
|
"string": "text",
|
||||||
|
"float64": "double",
|
||||||
|
"[]byte": "blob",
|
||||||
|
}
|
||||||
|
|
||||||
|
for goType, expectedStd := range expectedMappings {
|
||||||
|
if stdType, ok := GoToStdTypes[goType]; !ok {
|
||||||
|
t.Errorf("GoToStdTypes missing entry for %q", goType)
|
||||||
|
} else if stdType != expectedStd {
|
||||||
|
t.Errorf("GoToStdTypes[%q] = %q, want %q", goType, stdType, expectedStd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the map is not empty
|
||||||
|
if len(GoToStdTypes) == 0 {
|
||||||
|
t.Error("GoToStdTypes map is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoToPGSQLTypesMap(t *testing.T) {
|
||||||
|
// Test that the map contains expected entries
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"bool": "boolean",
|
||||||
|
"int64": "bigint",
|
||||||
|
"int": "integer",
|
||||||
|
"string": "text",
|
||||||
|
"float64": "double precision",
|
||||||
|
"float32": "real",
|
||||||
|
"[]byte": "bytea",
|
||||||
|
}
|
||||||
|
|
||||||
|
for goType, expectedPG := range expectedMappings {
|
||||||
|
if pgType, ok := GoToPGSQLTypes[goType]; !ok {
|
||||||
|
t.Errorf("GoToPGSQLTypes missing entry for %q", goType)
|
||||||
|
} else if pgType != expectedPG {
|
||||||
|
t.Errorf("GoToPGSQLTypes[%q] = %q, want %q", goType, pgType, expectedPG)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the map is not empty
|
||||||
|
if len(GoToPGSQLTypes) == 0 {
|
||||||
|
t.Error("GoToPGSQLTypes map is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeConversionConsistency(t *testing.T) {
|
||||||
|
// Test that GetSQLType and ConvertSQLType are consistent for known types
|
||||||
|
knownGoTypes := []string{"bool", "int64", "int", "string", "float64", "[]byte"}
|
||||||
|
|
||||||
|
for _, goType := range knownGoTypes {
|
||||||
|
getSQLResult := GetSQLType(goType)
|
||||||
|
convertResult := ConvertSQLType(goType)
|
||||||
|
|
||||||
|
if getSQLResult != convertResult {
|
||||||
|
t.Errorf("Inconsistent results for %q: GetSQLType=%q, ConvertSQLType=%q",
|
||||||
|
goType, getSQLResult, convertResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLTypeVsConvertSQLTypeDifference(t *testing.T) {
|
||||||
|
// Test that GetSQLType returns "text" for unknown types
|
||||||
|
// while ConvertSQLType returns the input unchanged
|
||||||
|
unknownTypes := []string{"varchar", "char", "customtype", "unknowntype"}
|
||||||
|
|
||||||
|
for _, unknown := range unknownTypes {
|
||||||
|
getSQLResult := GetSQLType(unknown)
|
||||||
|
convertResult := ConvertSQLType(unknown)
|
||||||
|
|
||||||
|
if getSQLResult != "text" {
|
||||||
|
t.Errorf("GetSQLType(%q) = %q, want %q", unknown, getSQLResult, "text")
|
||||||
|
}
|
||||||
|
|
||||||
|
if convertResult != unknown {
|
||||||
|
t.Errorf("ConvertSQLType(%q) = %q, want %q", unknown, convertResult, unknown)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
136
pkg/pgsql/keywords_test.go
Normal file
136
pkg/pgsql/keywords_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package pgsql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetPostgresKeywords(t *testing.T) {
|
||||||
|
keywords := GetPostgresKeywords()
|
||||||
|
|
||||||
|
// Test that keywords are returned
|
||||||
|
if len(keywords) == 0 {
|
||||||
|
t.Fatal("Expected non-empty list of keywords")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that we get all keywords from the map
|
||||||
|
expectedCount := len(postgresKeywords)
|
||||||
|
if len(keywords) != expectedCount {
|
||||||
|
t.Errorf("Expected %d keywords, got %d", expectedCount, len(keywords))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that all returned keywords exist in the map
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if !postgresKeywords[keyword] {
|
||||||
|
t.Errorf("Keyword %q not found in postgresKeywords map", keyword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that no duplicate keywords are returned
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if seen[keyword] {
|
||||||
|
t.Errorf("Duplicate keyword found: %q", keyword)
|
||||||
|
}
|
||||||
|
seen[keyword] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresKeywordsMap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keyword string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"SELECT keyword", "select", true},
|
||||||
|
{"FROM keyword", "from", true},
|
||||||
|
{"WHERE keyword", "where", true},
|
||||||
|
{"TABLE keyword", "table", true},
|
||||||
|
{"PRIMARY keyword", "primary", true},
|
||||||
|
{"FOREIGN keyword", "foreign", true},
|
||||||
|
{"CREATE keyword", "create", true},
|
||||||
|
{"DROP keyword", "drop", true},
|
||||||
|
{"ALTER keyword", "alter", true},
|
||||||
|
{"INDEX keyword", "index", true},
|
||||||
|
{"NOT keyword", "not", true},
|
||||||
|
{"NULL keyword", "null", true},
|
||||||
|
{"TRUE keyword", "true", true},
|
||||||
|
{"FALSE keyword", "false", true},
|
||||||
|
{"Non-keyword lowercase", "notakeyword", false},
|
||||||
|
{"Non-keyword uppercase", "NOTAKEYWORD", false},
|
||||||
|
{"Empty string", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := postgresKeywords[tt.keyword]
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("postgresKeywords[%q] = %v, want %v", tt.keyword, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresKeywordsMapContent(t *testing.T) {
|
||||||
|
// Test that the map contains expected common keywords
|
||||||
|
commonKeywords := []string{
|
||||||
|
"select", "insert", "update", "delete", "create", "drop", "alter",
|
||||||
|
"table", "index", "view", "schema", "function", "procedure",
|
||||||
|
"primary", "foreign", "key", "constraint", "unique", "check",
|
||||||
|
"null", "not", "and", "or", "like", "in", "between",
|
||||||
|
"join", "inner", "left", "right", "cross", "full", "outer",
|
||||||
|
"where", "having", "group", "order", "limit", "offset",
|
||||||
|
"union", "intersect", "except",
|
||||||
|
"begin", "commit", "rollback", "transaction",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, keyword := range commonKeywords {
|
||||||
|
if !postgresKeywords[keyword] {
|
||||||
|
t.Errorf("Expected common keyword %q to be in postgresKeywords map", keyword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresKeywordsMapSize(t *testing.T) {
|
||||||
|
// PostgreSQL has a substantial list of reserved keywords
|
||||||
|
// This test ensures the map has a reasonable number of entries
|
||||||
|
minExpectedKeywords := 200 // PostgreSQL 13+ has 400+ reserved words
|
||||||
|
|
||||||
|
if len(postgresKeywords) < minExpectedKeywords {
|
||||||
|
t.Errorf("Expected at least %d keywords, got %d. The map may be incomplete.",
|
||||||
|
minExpectedKeywords, len(postgresKeywords))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPostgresKeywordsConsistency(t *testing.T) {
|
||||||
|
// Test that calling GetPostgresKeywords multiple times returns consistent results
|
||||||
|
keywords1 := GetPostgresKeywords()
|
||||||
|
keywords2 := GetPostgresKeywords()
|
||||||
|
|
||||||
|
if len(keywords1) != len(keywords2) {
|
||||||
|
t.Errorf("Inconsistent results: first call returned %d keywords, second call returned %d",
|
||||||
|
len(keywords1), len(keywords2))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a map from both results to compare
|
||||||
|
map1 := make(map[string]bool)
|
||||||
|
map2 := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, k := range keywords1 {
|
||||||
|
map1[k] = true
|
||||||
|
}
|
||||||
|
for _, k := range keywords2 {
|
||||||
|
map2[k] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that both contain the same keywords
|
||||||
|
for k := range map1 {
|
||||||
|
if !map2[k] {
|
||||||
|
t.Errorf("Keyword %q present in first call but not in second", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k := range map2 {
|
||||||
|
if !map1[k] {
|
||||||
|
t.Errorf("Keyword %q present in second call but not in first", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
490
pkg/reflectutil/helpers_test.go
Normal file
490
pkg/reflectutil/helpers_test.go
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
package reflectutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testStruct struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
Active bool
|
||||||
|
Nested *nestedStruct
|
||||||
|
Private string
|
||||||
|
}
|
||||||
|
|
||||||
|
type nestedStruct struct {
|
||||||
|
Value string
|
||||||
|
Count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeref(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
wantValid bool
|
||||||
|
wantKind reflect.Kind
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non-pointer int",
|
||||||
|
input: 42,
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.Int,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single pointer",
|
||||||
|
input: ptrInt(42),
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.Int,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "double pointer",
|
||||||
|
input: ptrPtr(ptrInt(42)),
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.Int,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil pointer",
|
||||||
|
input: (*int)(nil),
|
||||||
|
wantValid: false,
|
||||||
|
wantKind: reflect.Ptr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
input: "test",
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.String,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "struct",
|
||||||
|
input: testStruct{Name: "test"},
|
||||||
|
wantValid: true,
|
||||||
|
wantKind: reflect.Struct,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := reflect.ValueOf(tt.input)
|
||||||
|
got, valid := Deref(v)
|
||||||
|
|
||||||
|
if valid != tt.wantValid {
|
||||||
|
t.Errorf("Deref() valid = %v, want %v", valid, tt.wantValid)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Kind() != tt.wantKind {
|
||||||
|
t.Errorf("Deref() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDerefInterface(t *testing.T) {
|
||||||
|
i := 42
|
||||||
|
pi := &i
|
||||||
|
ppi := &pi
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
wantKind reflect.Kind
|
||||||
|
}{
|
||||||
|
{"int", 42, reflect.Int},
|
||||||
|
{"pointer to int", &i, reflect.Int},
|
||||||
|
{"double pointer to int", ppi, reflect.Int},
|
||||||
|
{"string", "test", reflect.String},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := DerefInterface(tt.input)
|
||||||
|
if got.Kind() != tt.wantKind {
|
||||||
|
t.Errorf("DerefInterface() kind = %v, want %v", got.Kind(), tt.wantKind)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFieldValue(t *testing.T) {
|
||||||
|
ts := testStruct{
|
||||||
|
Name: "John",
|
||||||
|
Age: 30,
|
||||||
|
Active: true,
|
||||||
|
Nested: &nestedStruct{Value: "nested", Count: 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
item interface{}
|
||||||
|
field string
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"struct field Name", ts, "Name", "John"},
|
||||||
|
{"struct field Age", ts, "Age", 30},
|
||||||
|
{"struct field Active", ts, "Active", true},
|
||||||
|
{"struct non-existent field", ts, "NonExistent", nil},
|
||||||
|
{"pointer to struct", &ts, "Name", "John"},
|
||||||
|
{"map string key", map[string]string{"key": "value"}, "key", "value"},
|
||||||
|
{"map int key", map[string]int{"count": 42}, "count", 42},
|
||||||
|
{"map non-existent key", map[string]string{"key": "value"}, "missing", nil},
|
||||||
|
{"nil pointer", (*testStruct)(nil), "Name", nil},
|
||||||
|
{"non-struct non-map", 42, "field", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetFieldValue(tt.item, tt.field)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("GetFieldValue() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsSliceOrArray(t *testing.T) {
|
||||||
|
arr := [3]int{1, 2, 3}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"slice", []int{1, 2, 3}, true},
|
||||||
|
{"array", arr, true},
|
||||||
|
{"pointer to slice", &[]int{1, 2, 3}, true},
|
||||||
|
{"string", "test", false},
|
||||||
|
{"int", 42, false},
|
||||||
|
{"map", map[string]int{}, false},
|
||||||
|
{"nil slice", ([]int)(nil), true}, // nil slice is still Kind==Slice
|
||||||
|
{"nil pointer", (*[]int)(nil), false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := IsSliceOrArray(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsSliceOrArray() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsMap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"map[string]int", map[string]int{"a": 1}, true},
|
||||||
|
{"map[int]string", map[int]string{1: "a"}, true},
|
||||||
|
{"pointer to map", &map[string]int{"a": 1}, true},
|
||||||
|
{"slice", []int{1, 2, 3}, false},
|
||||||
|
{"string", "test", false},
|
||||||
|
{"int", 42, false},
|
||||||
|
{"nil map", (map[string]int)(nil), true}, // nil map is still Kind==Map
|
||||||
|
{"nil pointer", (*map[string]int)(nil), false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := IsMap(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsMap() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSliceLen(t *testing.T) {
|
||||||
|
arr := [3]int{1, 2, 3}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"slice length 3", []int{1, 2, 3}, 3},
|
||||||
|
{"empty slice", []int{}, 0},
|
||||||
|
{"array length 3", arr, 3},
|
||||||
|
{"pointer to slice", &[]int{1, 2, 3}, 3},
|
||||||
|
{"not a slice", "test", 0},
|
||||||
|
{"int", 42, 0},
|
||||||
|
{"nil slice", ([]int)(nil), 0},
|
||||||
|
{"nil pointer", (*[]int)(nil), 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SliceLen(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("SliceLen() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapLen(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"map length 2", map[string]int{"a": 1, "b": 2}, 2},
|
||||||
|
{"empty map", map[string]int{}, 0},
|
||||||
|
{"pointer to map", &map[string]int{"a": 1}, 1},
|
||||||
|
{"not a map", []int{1, 2, 3}, 0},
|
||||||
|
{"string", "test", 0},
|
||||||
|
{"nil map", (map[string]int)(nil), 0},
|
||||||
|
{"nil pointer", (*map[string]int)(nil), 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MapLen(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("MapLen() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSliceToInterfaces(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want []interface{}
|
||||||
|
}{
|
||||||
|
{"int slice", []int{1, 2, 3}, []interface{}{1, 2, 3}},
|
||||||
|
{"string slice", []string{"a", "b"}, []interface{}{"a", "b"}},
|
||||||
|
{"empty slice", []int{}, []interface{}{}},
|
||||||
|
{"pointer to slice", &[]int{1, 2}, []interface{}{1, 2}},
|
||||||
|
{"not a slice", "test", []interface{}{}},
|
||||||
|
{"nil slice", ([]int)(nil), []interface{}{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SliceToInterfaces(tt.input)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SliceToInterfaces() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapKeys(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want []interface{}
|
||||||
|
}{
|
||||||
|
{"map with keys", map[string]int{"a": 1, "b": 2}, []interface{}{"a", "b"}},
|
||||||
|
{"empty map", map[string]int{}, []interface{}{}},
|
||||||
|
{"not a map", []int{1, 2, 3}, []interface{}{}},
|
||||||
|
{"nil map", (map[string]int)(nil), []interface{}{}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MapKeys(tt.input)
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Errorf("MapKeys() length = %v, want %v", len(got), len(tt.want))
|
||||||
|
}
|
||||||
|
// For maps, order is not guaranteed, so just check length
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapValues(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
want int // length of values
|
||||||
|
}{
|
||||||
|
{"map with values", map[string]int{"a": 1, "b": 2}, 2},
|
||||||
|
{"empty map", map[string]int{}, 0},
|
||||||
|
{"not a map", []int{1, 2, 3}, 0},
|
||||||
|
{"nil map", (map[string]int)(nil), 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MapValues(tt.input)
|
||||||
|
if len(got) != tt.want {
|
||||||
|
t.Errorf("MapValues() length = %v, want %v", len(got), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapGet(t *testing.T) {
|
||||||
|
m := map[string]int{"a": 1, "b": 2}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
key interface{}
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"existing key", m, "a", 1},
|
||||||
|
{"existing key b", m, "b", 2},
|
||||||
|
{"non-existing key", m, "c", nil},
|
||||||
|
{"pointer to map", &m, "a", 1},
|
||||||
|
{"not a map", []int{1, 2}, 0, nil},
|
||||||
|
{"nil map", (map[string]int)(nil), "a", nil},
|
||||||
|
{"nil pointer", (*map[string]int)(nil), "a", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MapGet(tt.input, tt.key)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("MapGet() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSliceIndex(t *testing.T) {
|
||||||
|
s := []int{10, 20, 30}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
slice interface{}
|
||||||
|
index int
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"index 0", s, 0, 10},
|
||||||
|
{"index 1", s, 1, 20},
|
||||||
|
{"index 2", s, 2, 30},
|
||||||
|
{"negative index", s, -1, nil},
|
||||||
|
{"out of bounds", s, 5, nil},
|
||||||
|
{"pointer to slice", &s, 1, 20},
|
||||||
|
{"not a slice", "test", 0, nil},
|
||||||
|
{"nil slice", ([]int)(nil), 0, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := SliceIndex(tt.slice, tt.index)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SliceIndex() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareValues(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a interface{}
|
||||||
|
b interface{}
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"both nil", nil, nil, 0},
|
||||||
|
{"a nil", nil, 5, -1},
|
||||||
|
{"b nil", 5, nil, 1},
|
||||||
|
{"equal strings", "abc", "abc", 0},
|
||||||
|
{"a less than b strings", "abc", "xyz", -1},
|
||||||
|
{"a greater than b strings", "xyz", "abc", 1},
|
||||||
|
{"equal ints", 5, 5, 0},
|
||||||
|
{"a less than b ints", 3, 7, -1},
|
||||||
|
{"a greater than b ints", 10, 5, 1},
|
||||||
|
{"equal floats", 3.14, 3.14, 0},
|
||||||
|
{"a less than b floats", 2.5, 5.5, -1},
|
||||||
|
{"a greater than b floats", 10.5, 5.5, 1},
|
||||||
|
{"equal uints", uint(5), uint(5), 0},
|
||||||
|
{"different types", "abc", 123, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := CompareValues(tt.a, tt.b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("CompareValues(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNestedValue(t *testing.T) {
|
||||||
|
nested := map[string]interface{}{
|
||||||
|
"level1": map[string]interface{}{
|
||||||
|
"level2": map[string]interface{}{
|
||||||
|
"value": "deep",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := testStruct{
|
||||||
|
Name: "John",
|
||||||
|
Nested: &nestedStruct{
|
||||||
|
Value: "nested value",
|
||||||
|
Count: 42,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
path string
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"empty path", nested, "", nested},
|
||||||
|
{"single level map", nested, "level1", nested["level1"]},
|
||||||
|
{"nested map", nested, "level1.level2", map[string]interface{}{"value": "deep"}},
|
||||||
|
{"deep nested map", nested, "level1.level2.value", "deep"},
|
||||||
|
{"struct field", ts, "Name", "John"},
|
||||||
|
{"nested struct field", ts, "Nested", ts.Nested},
|
||||||
|
{"non-existent path", nested, "missing.path", nil},
|
||||||
|
{"nil input", nil, "path", nil},
|
||||||
|
{"partial missing path", nested, "level1.missing", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetNestedValue(tt.input, tt.path)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("GetNestedValue() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeepEqual(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a interface{}
|
||||||
|
b interface{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"equal ints", 42, 42, true},
|
||||||
|
{"different ints", 42, 43, false},
|
||||||
|
{"equal strings", "test", "test", true},
|
||||||
|
{"different strings", "test", "other", false},
|
||||||
|
{"equal slices", []int{1, 2, 3}, []int{1, 2, 3}, true},
|
||||||
|
{"different slices", []int{1, 2, 3}, []int{1, 2, 4}, false},
|
||||||
|
{"equal maps", map[string]int{"a": 1}, map[string]int{"a": 1}, true},
|
||||||
|
{"different maps", map[string]int{"a": 1}, map[string]int{"a": 2}, false},
|
||||||
|
{"both nil", nil, nil, true},
|
||||||
|
{"one nil", nil, 42, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := DeepEqual(tt.a, tt.b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("DeepEqual(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
func ptrInt(i int) *int {
|
||||||
|
return &i
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrPtr(p *int) **int {
|
||||||
|
return &p
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user