2 Commits

Author SHA1 Message Date
5d9770b430 test(pgsql, reflectutil): add comprehensive test coverage
All checks were successful
CI / Test (1.24) (push) Successful in -26m14s
CI / Test (1.25) (push) Successful in -26m3s
CI / Lint (push) Successful in -26m28s
CI / Build (push) Successful in -26m41s
Integration Tests / Integration Tests (push) Successful in -26m21s
* Introduce tests for PostgreSQL data types and keywords.
* Implement tests for reflect utility functions.
* Ensure consistency and correctness of type conversions and keyword mappings.
* Validate behavior for various edge cases and input types.
2026-01-31 22:30:00 +02:00
f2d500f98d feat(merge): 🎉 Add support for constraints and indexes in merge results
All checks were successful
CI / Test (1.24) (push) Successful in -26m24s
CI / Test (1.25) (push) Successful in -26m10s
CI / Lint (push) Successful in -26m33s
CI / Build (push) Successful in -26m40s
Release / Build and Release (push) Successful in -26m23s
Integration Tests / Integration Tests (push) Successful in -25m53s
* Enhance MergeResult to track added constraints and indexes.
* Update merge logic to increment counters for added constraints and indexes.
* Modify GetMergeSummary to include constraints and indexes in the output.
* Add comprehensive tests for merging constraints and indexes.
2026-01-31 21:30:55 +02:00
12 changed files with 4999 additions and 8 deletions

View File

@@ -0,0 +1,714 @@
package commontypes
import (
"testing"
)
func TestExtractBaseType(t *testing.T) {
tests := []struct {
name string
sqlType string
want string
}{
{"varchar with length", "varchar(100)", "varchar"},
{"VARCHAR uppercase with length", "VARCHAR(255)", "varchar"},
{"numeric with precision", "numeric(10,2)", "numeric"},
{"NUMERIC uppercase", "NUMERIC(18,4)", "numeric"},
{"decimal with precision", "decimal(15,3)", "decimal"},
{"char with length", "char(50)", "char"},
{"simple integer", "integer", "integer"},
{"simple text", "text", "text"},
{"bigint", "bigint", "bigint"},
{"With spaces", " varchar(100) ", "varchar"},
{"No parentheses", "boolean", "boolean"},
{"Empty string", "", ""},
{"Mixed case", "VarChar(100)", "varchar"},
{"timestamp with time zone", "timestamp(6) with time zone", "timestamp"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ExtractBaseType(tt.sqlType)
if got != tt.want {
t.Errorf("ExtractBaseType(%q) = %q, want %q", tt.sqlType, got, tt.want)
}
})
}
}
func TestNormalizeType(t *testing.T) {
// NormalizeType is an alias for ExtractBaseType, test that they behave the same
testCases := []string{
"varchar(100)",
"numeric(10,2)",
"integer",
"text",
" VARCHAR(255) ",
}
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
extracted := ExtractBaseType(tc)
normalized := NormalizeType(tc)
if extracted != normalized {
t.Errorf("ExtractBaseType(%q) = %q, but NormalizeType(%q) = %q",
tc, extracted, tc, normalized)
}
})
}
}
func TestSQLToGo(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types (nullable)
{"integer nullable", "integer", true, "int32"},
{"bigint nullable", "bigint", true, "int64"},
{"smallint nullable", "smallint", true, "int16"},
{"serial nullable", "serial", true, "int32"},
// Integer types (not nullable)
{"integer not nullable", "integer", false, "*int32"},
{"bigint not nullable", "bigint", false, "*int64"},
{"smallint not nullable", "smallint", false, "*int16"},
// String types (nullable)
{"text nullable", "text", true, "string"},
{"varchar nullable", "varchar", true, "string"},
{"varchar with length nullable", "varchar(100)", true, "string"},
// String types (not nullable)
{"text not nullable", "text", false, "*string"},
{"varchar not nullable", "varchar", false, "*string"},
// Boolean
{"boolean nullable", "boolean", true, "bool"},
{"boolean not nullable", "boolean", false, "*bool"},
// Float types
{"real nullable", "real", true, "float32"},
{"double precision nullable", "double precision", true, "float64"},
{"real not nullable", "real", false, "*float32"},
{"double precision not nullable", "double precision", false, "*float64"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "time.Time"},
{"date nullable", "date", true, "time.Time"},
{"timestamp not nullable", "timestamp", false, "*time.Time"},
// Binary
{"bytea nullable", "bytea", true, "[]byte"},
{"bytea not nullable", "bytea", false, "[]byte"}, // Slices don't get pointer
// UUID
{"uuid nullable", "uuid", true, "string"},
{"uuid not nullable", "uuid", false, "*string"},
// JSON
{"json nullable", "json", true, "string"},
{"jsonb nullable", "jsonb", true, "string"},
// Array
{"array nullable", "array", true, "[]string"},
{"array not nullable", "array", false, "[]string"}, // Slices don't get pointer
// Unknown types
{"unknown type nullable", "unknowntype", true, "interface{}"},
{"unknown type not nullable", "unknowntype", false, "interface{}"}, // Interface doesn't get pointer
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToGo(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToGo(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestSQLToTypeScript(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types
{"integer nullable", "integer", true, "number"},
{"integer not nullable", "integer", false, "number | null"},
{"bigint nullable", "bigint", true, "number"},
{"bigint not nullable", "bigint", false, "number | null"},
// String types
{"text nullable", "text", true, "string"},
{"text not nullable", "text", false, "string | null"},
{"varchar nullable", "varchar", true, "string"},
{"varchar(100) nullable", "varchar(100)", true, "string"},
// Boolean
{"boolean nullable", "boolean", true, "boolean"},
{"boolean not nullable", "boolean", false, "boolean | null"},
// Float types
{"real nullable", "real", true, "number"},
{"double precision nullable", "double precision", true, "number"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "Date"},
{"date nullable", "date", true, "Date"},
{"timestamp not nullable", "timestamp", false, "Date | null"},
// Binary
{"bytea nullable", "bytea", true, "Buffer"},
{"bytea not nullable", "bytea", false, "Buffer | null"},
// JSON
{"json nullable", "json", true, "any"},
{"jsonb nullable", "jsonb", true, "any"},
// UUID
{"uuid nullable", "uuid", true, "string"},
// Unknown types
{"unknown type nullable", "unknowntype", true, "any"},
{"unknown type not nullable", "unknowntype", false, "any | null"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToTypeScript(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToTypeScript(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestSQLToPython(t *testing.T) {
tests := []struct {
name string
sqlType string
want string
}{
// Integer types
{"integer", "integer", "int"},
{"bigint", "bigint", "int"},
{"smallint", "smallint", "int"},
// String types
{"text", "text", "str"},
{"varchar", "varchar", "str"},
{"varchar(100)", "varchar(100)", "str"},
// Boolean
{"boolean", "boolean", "bool"},
// Float types
{"real", "real", "float"},
{"double precision", "double precision", "float"},
{"numeric", "numeric", "Decimal"},
{"decimal", "decimal", "Decimal"},
// Date/Time types
{"timestamp", "timestamp", "datetime"},
{"date", "date", "date"},
{"time", "time", "time"},
// Binary
{"bytea", "bytea", "bytes"},
// JSON
{"json", "json", "dict"},
{"jsonb", "jsonb", "dict"},
// UUID
{"uuid", "uuid", "UUID"},
// Array
{"array", "array", "list"},
// Unknown types
{"unknown type", "unknowntype", "Any"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToPython(tt.sqlType)
if got != tt.want {
t.Errorf("SQLToPython(%q) = %q, want %q", tt.sqlType, got, tt.want)
}
})
}
}
func TestSQLToCSharp(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types (nullable)
{"integer nullable", "integer", true, "int"},
{"bigint nullable", "bigint", true, "long"},
{"smallint nullable", "smallint", true, "short"},
// Integer types (not nullable - value types get ?)
{"integer not nullable", "integer", false, "int?"},
{"bigint not nullable", "bigint", false, "long?"},
{"smallint not nullable", "smallint", false, "short?"},
// String types (reference types, no ? needed)
{"text nullable", "text", true, "string"},
{"text not nullable", "text", false, "string"},
{"varchar nullable", "varchar", true, "string"},
{"varchar(100) nullable", "varchar(100)", true, "string"},
// Boolean
{"boolean nullable", "boolean", true, "bool"},
{"boolean not nullable", "boolean", false, "bool?"},
// Float types
{"real nullable", "real", true, "float"},
{"double precision nullable", "double precision", true, "double"},
{"decimal nullable", "decimal", true, "decimal"},
{"real not nullable", "real", false, "float?"},
{"double precision not nullable", "double precision", false, "double?"},
{"decimal not nullable", "decimal", false, "decimal?"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "DateTime"},
{"date nullable", "date", true, "DateTime"},
{"timestamptz nullable", "timestamptz", true, "DateTimeOffset"},
{"timestamp not nullable", "timestamp", false, "DateTime?"},
{"timestamptz not nullable", "timestamptz", false, "DateTimeOffset?"},
// Binary (array type, no ?)
{"bytea nullable", "bytea", true, "byte[]"},
{"bytea not nullable", "bytea", false, "byte[]"},
// UUID
{"uuid nullable", "uuid", true, "Guid"},
{"uuid not nullable", "uuid", false, "Guid?"},
// JSON
{"json nullable", "json", true, "string"},
// Unknown types (object is reference type)
{"unknown type nullable", "unknowntype", true, "object"},
{"unknown type not nullable", "unknowntype", false, "object"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToCSharp(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToCSharp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestNeedsTimeImport(t *testing.T) {
tests := []struct {
name string
goType string
want bool
}{
{"time.Time type", "time.Time", true},
{"pointer to time.Time", "*time.Time", true},
{"int32 type", "int32", false},
{"string type", "string", false},
{"bool type", "bool", false},
{"[]byte type", "[]byte", false},
{"interface{}", "interface{}", false},
{"empty string", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NeedsTimeImport(tt.goType)
if got != tt.want {
t.Errorf("NeedsTimeImport(%q) = %v, want %v", tt.goType, got, tt.want)
}
})
}
}
func TestGoTypeMap(t *testing.T) {
// Test that the map contains expected entries
expectedMappings := map[string]string{
"integer": "int32",
"bigint": "int64",
"text": "string",
"boolean": "bool",
"double precision": "float64",
"bytea": "[]byte",
"timestamp": "time.Time",
"uuid": "string",
"json": "string",
}
for sqlType, expectedGoType := range expectedMappings {
if goType, ok := GoTypeMap[sqlType]; !ok {
t.Errorf("GoTypeMap missing entry for %q", sqlType)
} else if goType != expectedGoType {
t.Errorf("GoTypeMap[%q] = %q, want %q", sqlType, goType, expectedGoType)
}
}
if len(GoTypeMap) == 0 {
t.Error("GoTypeMap is empty")
}
}
func TestTypeScriptTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "number",
"bigint": "number",
"text": "string",
"boolean": "boolean",
"double precision": "number",
"bytea": "Buffer",
"timestamp": "Date",
"uuid": "string",
"json": "any",
}
for sqlType, expectedTSType := range expectedMappings {
if tsType, ok := TypeScriptTypeMap[sqlType]; !ok {
t.Errorf("TypeScriptTypeMap missing entry for %q", sqlType)
} else if tsType != expectedTSType {
t.Errorf("TypeScriptTypeMap[%q] = %q, want %q", sqlType, tsType, expectedTSType)
}
}
if len(TypeScriptTypeMap) == 0 {
t.Error("TypeScriptTypeMap is empty")
}
}
func TestPythonTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "int",
"bigint": "int",
"text": "str",
"boolean": "bool",
"real": "float",
"numeric": "Decimal",
"bytea": "bytes",
"date": "date",
"uuid": "UUID",
"json": "dict",
}
for sqlType, expectedPyType := range expectedMappings {
if pyType, ok := PythonTypeMap[sqlType]; !ok {
t.Errorf("PythonTypeMap missing entry for %q", sqlType)
} else if pyType != expectedPyType {
t.Errorf("PythonTypeMap[%q] = %q, want %q", sqlType, pyType, expectedPyType)
}
}
if len(PythonTypeMap) == 0 {
t.Error("PythonTypeMap is empty")
}
}
func TestCSharpTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "int",
"bigint": "long",
"smallint": "short",
"text": "string",
"boolean": "bool",
"double precision": "double",
"decimal": "decimal",
"bytea": "byte[]",
"timestamp": "DateTime",
"uuid": "Guid",
"json": "string",
}
for sqlType, expectedCSType := range expectedMappings {
if csType, ok := CSharpTypeMap[sqlType]; !ok {
t.Errorf("CSharpTypeMap missing entry for %q", sqlType)
} else if csType != expectedCSType {
t.Errorf("CSharpTypeMap[%q] = %q, want %q", sqlType, csType, expectedCSType)
}
}
if len(CSharpTypeMap) == 0 {
t.Error("CSharpTypeMap is empty")
}
}
func TestSQLToJava(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types
{"integer nullable", "integer", true, "Integer"},
{"integer not nullable", "integer", false, "Integer"},
{"bigint nullable", "bigint", true, "Long"},
{"smallint nullable", "smallint", true, "Short"},
// String types
{"text nullable", "text", true, "String"},
{"varchar nullable", "varchar", true, "String"},
{"varchar(100) nullable", "varchar(100)", true, "String"},
// Boolean
{"boolean nullable", "boolean", true, "Boolean"},
// Float types
{"real nullable", "real", true, "Float"},
{"double precision nullable", "double precision", true, "Double"},
{"numeric nullable", "numeric", true, "BigDecimal"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "Timestamp"},
{"date nullable", "date", true, "Date"},
{"time nullable", "time", true, "Time"},
// Binary
{"bytea nullable", "bytea", true, "byte[]"},
// UUID
{"uuid nullable", "uuid", true, "UUID"},
// JSON
{"json nullable", "json", true, "String"},
// Unknown types
{"unknown type nullable", "unknowntype", true, "Object"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToJava(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToJava(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestSQLToPhp(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types (nullable)
{"integer nullable", "integer", true, "int"},
{"bigint nullable", "bigint", true, "int"},
{"smallint nullable", "smallint", true, "int"},
// Integer types (not nullable)
{"integer not nullable", "integer", false, "?int"},
{"bigint not nullable", "bigint", false, "?int"},
// String types
{"text nullable", "text", true, "string"},
{"text not nullable", "text", false, "?string"},
{"varchar nullable", "varchar", true, "string"},
{"varchar(100) nullable", "varchar(100)", true, "string"},
// Boolean
{"boolean nullable", "boolean", true, "bool"},
{"boolean not nullable", "boolean", false, "?bool"},
// Float types
{"real nullable", "real", true, "float"},
{"double precision nullable", "double precision", true, "float"},
{"real not nullable", "real", false, "?float"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "\\DateTime"},
{"date nullable", "date", true, "\\DateTime"},
{"timestamp not nullable", "timestamp", false, "?\\DateTime"},
// Binary
{"bytea nullable", "bytea", true, "string"},
{"bytea not nullable", "bytea", false, "?string"},
// JSON
{"json nullable", "json", true, "array"},
{"json not nullable", "json", false, "?array"},
// UUID
{"uuid nullable", "uuid", true, "string"},
// Unknown types
{"unknown type nullable", "unknowntype", true, "mixed"},
{"unknown type not nullable", "unknowntype", false, "mixed"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToPhp(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToPhp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestSQLToRust(t *testing.T) {
tests := []struct {
name string
sqlType string
nullable bool
want string
}{
// Integer types (nullable)
{"integer nullable", "integer", true, "i32"},
{"bigint nullable", "bigint", true, "i64"},
{"smallint nullable", "smallint", true, "i16"},
// Integer types (not nullable)
{"integer not nullable", "integer", false, "Option<i32>"},
{"bigint not nullable", "bigint", false, "Option<i64>"},
{"smallint not nullable", "smallint", false, "Option<i16>"},
// String types
{"text nullable", "text", true, "String"},
{"text not nullable", "text", false, "Option<String>"},
{"varchar nullable", "varchar", true, "String"},
{"varchar(100) nullable", "varchar(100)", true, "String"},
// Boolean
{"boolean nullable", "boolean", true, "bool"},
{"boolean not nullable", "boolean", false, "Option<bool>"},
// Float types
{"real nullable", "real", true, "f32"},
{"double precision nullable", "double precision", true, "f64"},
{"real not nullable", "real", false, "Option<f32>"},
{"double precision not nullable", "double precision", false, "Option<f64>"},
// Date/Time types
{"timestamp nullable", "timestamp", true, "NaiveDateTime"},
{"timestamptz nullable", "timestamptz", true, "DateTime<Utc>"},
{"date nullable", "date", true, "NaiveDate"},
{"time nullable", "time", true, "NaiveTime"},
{"timestamp not nullable", "timestamp", false, "Option<NaiveDateTime>"},
// Binary
{"bytea nullable", "bytea", true, "Vec<u8>"},
{"bytea not nullable", "bytea", false, "Option<Vec<u8>>"},
// JSON
{"json nullable", "json", true, "serde_json::Value"},
{"json not nullable", "json", false, "Option<serde_json::Value>"},
// UUID
{"uuid nullable", "uuid", true, "String"},
// Unknown types
{"unknown type nullable", "unknowntype", true, "String"},
{"unknown type not nullable", "unknowntype", false, "Option<String>"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SQLToRust(tt.sqlType, tt.nullable)
if got != tt.want {
t.Errorf("SQLToRust(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want)
}
})
}
}
func TestJavaTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "Integer",
"bigint": "Long",
"smallint": "Short",
"text": "String",
"boolean": "Boolean",
"double precision": "Double",
"numeric": "BigDecimal",
"bytea": "byte[]",
"timestamp": "Timestamp",
"uuid": "UUID",
"date": "Date",
}
for sqlType, expectedJavaType := range expectedMappings {
if javaType, ok := JavaTypeMap[sqlType]; !ok {
t.Errorf("JavaTypeMap missing entry for %q", sqlType)
} else if javaType != expectedJavaType {
t.Errorf("JavaTypeMap[%q] = %q, want %q", sqlType, javaType, expectedJavaType)
}
}
if len(JavaTypeMap) == 0 {
t.Error("JavaTypeMap is empty")
}
}
func TestPHPTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "int",
"bigint": "int",
"text": "string",
"boolean": "bool",
"double precision": "float",
"bytea": "string",
"timestamp": "\\DateTime",
"uuid": "string",
"json": "array",
}
for sqlType, expectedPHPType := range expectedMappings {
if phpType, ok := PHPTypeMap[sqlType]; !ok {
t.Errorf("PHPTypeMap missing entry for %q", sqlType)
} else if phpType != expectedPHPType {
t.Errorf("PHPTypeMap[%q] = %q, want %q", sqlType, phpType, expectedPHPType)
}
}
if len(PHPTypeMap) == 0 {
t.Error("PHPTypeMap is empty")
}
}
func TestRustTypeMap(t *testing.T) {
expectedMappings := map[string]string{
"integer": "i32",
"bigint": "i64",
"smallint": "i16",
"text": "String",
"boolean": "bool",
"double precision": "f64",
"real": "f32",
"bytea": "Vec<u8>",
"timestamp": "NaiveDateTime",
"timestamptz": "DateTime<Utc>",
"date": "NaiveDate",
"json": "serde_json::Value",
}
for sqlType, expectedRustType := range expectedMappings {
if rustType, ok := RustTypeMap[sqlType]; !ok {
t.Errorf("RustTypeMap missing entry for %q", sqlType)
} else if rustType != expectedRustType {
t.Errorf("RustTypeMap[%q] = %q, want %q", sqlType, rustType, expectedRustType)
}
}
if len(RustTypeMap) == 0 {
t.Error("RustTypeMap is empty")
}
}

558
pkg/diff/diff_test.go Normal file
View File

@@ -0,0 +1,558 @@
package diff
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestCompareDatabases(t *testing.T) {
tests := []struct {
name string
source *models.Database
target *models.Database
want func(*DiffResult) bool
}{
{
name: "identical databases",
source: &models.Database{
Name: "source",
Schemas: []*models.Schema{},
},
target: &models.Database{
Name: "target",
Schemas: []*models.Schema{},
},
want: func(r *DiffResult) bool {
return r.Source == "source" && r.Target == "target" &&
len(r.Schemas.Missing) == 0 && len(r.Schemas.Extra) == 0
},
},
{
name: "different schemas",
source: &models.Database{
Name: "source",
Schemas: []*models.Schema{
{Name: "public"},
},
},
target: &models.Database{
Name: "target",
Schemas: []*models.Schema{},
},
want: func(r *DiffResult) bool {
return len(r.Schemas.Missing) == 1 && r.Schemas.Missing[0].Name == "public"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := CompareDatabases(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("CompareDatabases() result doesn't match expectations")
}
})
}
}
func TestCompareColumns(t *testing.T) {
tests := []struct {
name string
source map[string]*models.Column
target map[string]*models.Column
want func(*ColumnDiff) bool
}{
{
name: "identical columns",
source: map[string]*models.Column{},
target: map[string]*models.Column{},
want: func(d *ColumnDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing column",
source: map[string]*models.Column{
"id": {Name: "id", Type: "integer"},
},
target: map[string]*models.Column{},
want: func(d *ColumnDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "id"
},
},
{
name: "extra column",
source: map[string]*models.Column{},
target: map[string]*models.Column{
"id": {Name: "id", Type: "integer"},
},
want: func(d *ColumnDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "id"
},
},
{
name: "modified column type",
source: map[string]*models.Column{
"id": {Name: "id", Type: "integer"},
},
target: map[string]*models.Column{
"id": {Name: "id", Type: "bigint"},
},
want: func(d *ColumnDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Name == "id" &&
d.Modified[0].Changes["type"] != nil
},
},
{
name: "modified column nullable",
source: map[string]*models.Column{
"name": {Name: "name", Type: "text", NotNull: true},
},
target: map[string]*models.Column{
"name": {Name: "name", Type: "text", NotNull: false},
},
want: func(d *ColumnDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Changes["not_null"] != nil
},
},
{
name: "modified column length",
source: map[string]*models.Column{
"name": {Name: "name", Type: "varchar", Length: 100},
},
target: map[string]*models.Column{
"name": {Name: "name", Type: "varchar", Length: 255},
},
want: func(d *ColumnDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Changes["length"] != nil
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareColumns(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareColumns() result doesn't match expectations")
}
})
}
}
func TestCompareColumnDetails(t *testing.T) {
tests := []struct {
name string
source *models.Column
target *models.Column
want int // number of changes
}{
{
name: "identical columns",
source: &models.Column{Name: "id", Type: "integer"},
target: &models.Column{Name: "id", Type: "integer"},
want: 0,
},
{
name: "type change",
source: &models.Column{Name: "id", Type: "integer"},
target: &models.Column{Name: "id", Type: "bigint"},
want: 1,
},
{
name: "length change",
source: &models.Column{Name: "name", Type: "varchar", Length: 100},
target: &models.Column{Name: "name", Type: "varchar", Length: 255},
want: 1,
},
{
name: "precision change",
source: &models.Column{Name: "price", Type: "numeric", Precision: 10},
target: &models.Column{Name: "price", Type: "numeric", Precision: 12},
want: 1,
},
{
name: "scale change",
source: &models.Column{Name: "price", Type: "numeric", Scale: 2},
target: &models.Column{Name: "price", Type: "numeric", Scale: 4},
want: 1,
},
{
name: "not null change",
source: &models.Column{Name: "name", Type: "text", NotNull: true},
target: &models.Column{Name: "name", Type: "text", NotNull: false},
want: 1,
},
{
name: "auto increment change",
source: &models.Column{Name: "id", Type: "integer", AutoIncrement: true},
target: &models.Column{Name: "id", Type: "integer", AutoIncrement: false},
want: 1,
},
{
name: "primary key change",
source: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: true},
target: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: false},
want: 1,
},
{
name: "multiple changes",
source: &models.Column{Name: "id", Type: "integer", NotNull: true, AutoIncrement: true},
target: &models.Column{Name: "id", Type: "bigint", NotNull: false, AutoIncrement: false},
want: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareColumnDetails(tt.source, tt.target)
if len(got) != tt.want {
t.Errorf("compareColumnDetails() = %d changes, want %d", len(got), tt.want)
}
})
}
}
func TestCompareIndexes(t *testing.T) {
tests := []struct {
name string
source map[string]*models.Index
target map[string]*models.Index
want func(*IndexDiff) bool
}{
{
name: "identical indexes",
source: map[string]*models.Index{},
target: map[string]*models.Index{},
want: func(d *IndexDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing index",
source: map[string]*models.Index{
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
},
target: map[string]*models.Index{},
want: func(d *IndexDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "idx_name"
},
},
{
name: "extra index",
source: map[string]*models.Index{},
target: map[string]*models.Index{
"idx_name": {Name: "idx_name", Columns: []string{"name"}},
},
want: func(d *IndexDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "idx_name"
},
},
{
name: "modified index uniqueness",
source: map[string]*models.Index{
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: false},
},
target: map[string]*models.Index{
"idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: true},
},
want: func(d *IndexDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Name == "idx_name"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareIndexes(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareIndexes() result doesn't match expectations")
}
})
}
}
func TestCompareConstraints(t *testing.T) {
tests := []struct {
name string
source map[string]*models.Constraint
target map[string]*models.Constraint
want func(*ConstraintDiff) bool
}{
{
name: "identical constraints",
source: map[string]*models.Constraint{},
target: map[string]*models.Constraint{},
want: func(d *ConstraintDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing constraint",
source: map[string]*models.Constraint{
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
},
target: map[string]*models.Constraint{},
want: func(d *ConstraintDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "pk_id"
},
},
{
name: "extra constraint",
source: map[string]*models.Constraint{},
target: map[string]*models.Constraint{
"pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}},
},
want: func(d *ConstraintDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "pk_id"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareConstraints(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareConstraints() result doesn't match expectations")
}
})
}
}
func TestCompareRelationships(t *testing.T) {
tests := []struct {
name string
source map[string]*models.Relationship
target map[string]*models.Relationship
want func(*RelationshipDiff) bool
}{
{
name: "identical relationships",
source: map[string]*models.Relationship{},
target: map[string]*models.Relationship{},
want: func(d *RelationshipDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing relationship",
source: map[string]*models.Relationship{
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
},
target: map[string]*models.Relationship{},
want: func(d *RelationshipDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "fk_user"
},
},
{
name: "extra relationship",
source: map[string]*models.Relationship{},
target: map[string]*models.Relationship{
"fk_user": {Name: "fk_user", Type: "FOREIGN KEY"},
},
want: func(d *RelationshipDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "fk_user"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareRelationships(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareRelationships() result doesn't match expectations")
}
})
}
}
func TestCompareTables(t *testing.T) {
tests := []struct {
name string
source []*models.Table
target []*models.Table
want func(*TableDiff) bool
}{
{
name: "identical tables",
source: []*models.Table{},
target: []*models.Table{},
want: func(d *TableDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing table",
source: []*models.Table{
{Name: "users", Schema: "public"},
},
target: []*models.Table{},
want: func(d *TableDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "users"
},
},
{
name: "extra table",
source: []*models.Table{},
target: []*models.Table{
{Name: "users", Schema: "public"},
},
want: func(d *TableDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "users"
},
},
{
name: "modified table",
source: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "integer"},
},
},
},
target: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "bigint"},
},
},
},
want: func(d *TableDiff) bool {
return len(d.Modified) == 1 && d.Modified[0].Name == "users"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareTables(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareTables() result doesn't match expectations")
}
})
}
}
func TestCompareSchemas(t *testing.T) {
tests := []struct {
name string
source []*models.Schema
target []*models.Schema
want func(*SchemaDiff) bool
}{
{
name: "identical schemas",
source: []*models.Schema{},
target: []*models.Schema{},
want: func(d *SchemaDiff) bool {
return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0
},
},
{
name: "missing schema",
source: []*models.Schema{
{Name: "public"},
},
target: []*models.Schema{},
want: func(d *SchemaDiff) bool {
return len(d.Missing) == 1 && d.Missing[0].Name == "public"
},
},
{
name: "extra schema",
source: []*models.Schema{},
target: []*models.Schema{
{Name: "public"},
},
want: func(d *SchemaDiff) bool {
return len(d.Extra) == 1 && d.Extra[0].Name == "public"
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := compareSchemas(tt.source, tt.target)
if !tt.want(got) {
t.Errorf("compareSchemas() result doesn't match expectations")
}
})
}
}
func TestIsEmpty(t *testing.T) {
tests := []struct {
name string
v interface{}
want bool
}{
{"empty ColumnDiff", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, true},
{"ColumnDiff with missing", &ColumnDiff{Missing: []*models.Column{{Name: "id"}}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, false},
{"ColumnDiff with extra", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{{Name: "id"}}, Modified: []*ColumnChange{}}, false},
{"empty IndexDiff", &IndexDiff{Missing: []*models.Index{}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, true},
{"IndexDiff with missing", &IndexDiff{Missing: []*models.Index{{Name: "idx"}}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, false},
{"empty TableDiff", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{}, Modified: []*TableChange{}}, true},
{"TableDiff with extra", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{{Name: "users"}}, Modified: []*TableChange{}}, false},
{"empty ConstraintDiff", &ConstraintDiff{Missing: []*models.Constraint{}, Extra: []*models.Constraint{}, Modified: []*ConstraintChange{}}, true},
{"empty RelationshipDiff", &RelationshipDiff{Missing: []*models.Relationship{}, Extra: []*models.Relationship{}, Modified: []*RelationshipChange{}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isEmpty(tt.v)
if got != tt.want {
t.Errorf("isEmpty() = %v, want %v", got, tt.want)
}
})
}
}
func TestComputeSummary(t *testing.T) {
tests := []struct {
name string
result *DiffResult
want func(*Summary) bool
}{
{
name: "empty diff",
result: &DiffResult{
Schemas: &SchemaDiff{
Missing: []*models.Schema{},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
},
want: func(s *Summary) bool {
return s.Schemas.Missing == 0 && s.Schemas.Extra == 0 && s.Schemas.Modified == 0
},
},
{
name: "schemas with differences",
result: &DiffResult{
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "schema1"}},
Extra: []*models.Schema{{Name: "schema2"}, {Name: "schema3"}},
Modified: []*SchemaChange{
{Name: "public"},
},
},
},
want: func(s *Summary) bool {
return s.Schemas.Missing == 1 && s.Schemas.Extra == 2 && s.Schemas.Modified == 1
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ComputeSummary(tt.result)
if !tt.want(got) {
t.Errorf("ComputeSummary() result doesn't match expectations")
}
})
}
}

440
pkg/diff/formatters_test.go Normal file
View File

@@ -0,0 +1,440 @@
package diff
import (
"bytes"
"encoding/json"
"strings"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestFormatDiff(t *testing.T) {
result := &DiffResult{
Source: "source_db",
Target: "target_db",
Schemas: &SchemaDiff{
Missing: []*models.Schema{},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
}
tests := []struct {
name string
format OutputFormat
wantErr bool
}{
{"summary format", FormatSummary, false},
{"json format", FormatJSON, false},
{"html format", FormatHTML, false},
{"invalid format", OutputFormat("invalid"), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := FormatDiff(result, tt.format, &buf)
if (err != nil) != tt.wantErr {
t.Errorf("FormatDiff() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && buf.Len() == 0 {
t.Error("FormatDiff() produced empty output")
}
})
}
}
func TestFormatSummary(t *testing.T) {
tests := []struct {
name string
result *DiffResult
wantStr []string // strings that should appear in output
}{
{
name: "no differences",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
},
wantStr: []string{"source", "target", "No differences found"},
},
{
name: "with schema differences",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "schema1"}},
Extra: []*models.Schema{{Name: "schema2"}},
Modified: []*SchemaChange{
{Name: "public"},
},
},
},
wantStr: []string{"Schemas:", "Missing: 1", "Extra: 1", "Modified: 1"},
},
{
name: "with table differences",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Missing: []*models.Table{{Name: "users"}},
Extra: []*models.Table{{Name: "posts"}},
Modified: []*TableChange{
{Name: "comments", Schema: "public"},
},
},
},
},
},
},
wantStr: []string{"Tables:", "Missing: 1", "Extra: 1", "Modified: 1"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := formatSummary(tt.result, &buf)
if err != nil {
t.Errorf("formatSummary() error = %v", err)
return
}
output := buf.String()
for _, want := range tt.wantStr {
if !strings.Contains(output, want) {
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
}
}
})
}
}
func TestFormatJSON(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "schema1"}},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
}
var buf bytes.Buffer
err := formatJSON(result, &buf)
if err != nil {
t.Errorf("formatJSON() error = %v", err)
return
}
// Check if output is valid JSON
var decoded DiffResult
if err := json.Unmarshal(buf.Bytes(), &decoded); err != nil {
t.Errorf("formatJSON() produced invalid JSON: %v", err)
}
// Check basic structure
if decoded.Source != "source" {
t.Errorf("formatJSON() source = %v, want %v", decoded.Source, "source")
}
if decoded.Target != "target" {
t.Errorf("formatJSON() target = %v, want %v", decoded.Target, "target")
}
if len(decoded.Schemas.Missing) != 1 {
t.Errorf("formatJSON() missing schemas = %v, want 1", len(decoded.Schemas.Missing))
}
}
func TestFormatHTML(t *testing.T) {
tests := []struct {
name string
result *DiffResult
wantStr []string // HTML elements/content that should appear
}{
{
name: "basic HTML structure",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{},
Extra: []*models.Schema{},
Modified: []*SchemaChange{},
},
},
wantStr: []string{
"<!DOCTYPE html>",
"<title>Database Diff Report</title>",
"source",
"target",
},
},
{
name: "with schema differences",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "missing_schema"}},
Extra: []*models.Schema{{Name: "extra_schema"}},
Modified: []*SchemaChange{},
},
},
wantStr: []string{
"<!DOCTYPE html>",
"missing_schema",
"extra_schema",
"MISSING",
"EXTRA",
},
},
{
name: "with table modifications",
result: &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Modified: []*TableChange{
{
Name: "users",
Schema: "public",
Columns: &ColumnDiff{
Missing: []*models.Column{{Name: "email", Type: "text"}},
},
},
},
},
},
},
},
},
wantStr: []string{
"public",
"users",
"email",
"text",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := formatHTML(tt.result, &buf)
if err != nil {
t.Errorf("formatHTML() error = %v", err)
return
}
output := buf.String()
for _, want := range tt.wantStr {
if !strings.Contains(output, want) {
t.Errorf("formatHTML() output doesn't contain %q", want)
}
}
})
}
}
func TestFormatSummaryWithColumns(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Modified: []*TableChange{
{
Name: "users",
Schema: "public",
Columns: &ColumnDiff{
Missing: []*models.Column{{Name: "email"}},
Extra: []*models.Column{{Name: "phone"}, {Name: "address"}},
Modified: []*ColumnChange{
{Name: "name"},
},
},
},
},
},
},
},
},
}
var buf bytes.Buffer
err := formatSummary(result, &buf)
if err != nil {
t.Errorf("formatSummary() error = %v", err)
return
}
output := buf.String()
wantStrings := []string{
"Columns:",
"Missing: 1",
"Extra: 2",
"Modified: 1",
}
for _, want := range wantStrings {
if !strings.Contains(output, want) {
t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output)
}
}
}
func TestFormatSummaryWithIndexes(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Modified: []*TableChange{
{
Name: "users",
Schema: "public",
Indexes: &IndexDiff{
Missing: []*models.Index{{Name: "idx_email"}},
Extra: []*models.Index{{Name: "idx_phone"}},
Modified: []*IndexChange{{Name: "idx_name"}},
},
},
},
},
},
},
},
}
var buf bytes.Buffer
err := formatSummary(result, &buf)
if err != nil {
t.Errorf("formatSummary() error = %v", err)
return
}
output := buf.String()
if !strings.Contains(output, "Indexes:") {
t.Error("formatSummary() output doesn't contain Indexes section")
}
if !strings.Contains(output, "Missing: 1") {
t.Error("formatSummary() output doesn't contain correct missing count")
}
}
func TestFormatSummaryWithConstraints(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Modified: []*SchemaChange{
{
Name: "public",
Tables: &TableDiff{
Modified: []*TableChange{
{
Name: "users",
Schema: "public",
Constraints: &ConstraintDiff{
Missing: []*models.Constraint{{Name: "pk_users", Type: "PRIMARY KEY"}},
Extra: []*models.Constraint{{Name: "fk_users_roles", Type: "FOREIGN KEY"}},
},
},
},
},
},
},
},
}
var buf bytes.Buffer
err := formatSummary(result, &buf)
if err != nil {
t.Errorf("formatSummary() error = %v", err)
return
}
output := buf.String()
if !strings.Contains(output, "Constraints:") {
t.Error("formatSummary() output doesn't contain Constraints section")
}
}
func TestFormatJSONIndentation(t *testing.T) {
result := &DiffResult{
Source: "source",
Target: "target",
Schemas: &SchemaDiff{
Missing: []*models.Schema{{Name: "test"}},
},
}
var buf bytes.Buffer
err := formatJSON(result, &buf)
if err != nil {
t.Errorf("formatJSON() error = %v", err)
return
}
// Check that JSON is indented (has newlines and spaces)
output := buf.String()
if !strings.Contains(output, "\n") {
t.Error("formatJSON() should produce indented JSON with newlines")
}
if !strings.Contains(output, " ") {
t.Error("formatJSON() should produce indented JSON with spaces")
}
}
func TestOutputFormatConstants(t *testing.T) {
tests := []struct {
name string
format OutputFormat
want string
}{
{"summary constant", FormatSummary, "summary"},
{"json constant", FormatJSON, "json"},
{"html constant", FormatHTML, "html"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.format) != tt.want {
t.Errorf("OutputFormat %v = %v, want %v", tt.name, tt.format, tt.want)
}
})
}
}

View File

@@ -0,0 +1,238 @@
package inspector
import (
"testing"
)
func TestNewInspector(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
if inspector == nil {
t.Fatal("NewInspector() returned nil")
}
if inspector.db != db {
t.Error("NewInspector() database not set correctly")
}
if inspector.config != config {
t.Error("NewInspector() config not set correctly")
}
}
func TestInspect(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() returned error: %v", err)
}
if report == nil {
t.Fatal("Inspect() returned nil report")
}
if report.Database != db.Name {
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
}
if report.Summary.TotalRules != len(config.Rules) {
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
}
if len(report.Violations) == 0 {
t.Error("Inspect() returned no violations, expected some results")
}
}
func TestInspectWithDisabledRules(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
// Disable all rules
for name := range config.Rules {
rule := config.Rules[name]
rule.Enabled = "off"
config.Rules[name] = rule
}
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
}
if report.Summary.RulesChecked != 0 {
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
}
if report.Summary.RulesSkipped != len(config.Rules) {
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
}
}
func TestInspectWithEnforcedRules(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
// Enable only one rule and enforce it
for name := range config.Rules {
rule := config.Rules[name]
rule.Enabled = "off"
config.Rules[name] = rule
}
primaryKeyRule := config.Rules["primary_key_naming"]
primaryKeyRule.Enabled = "enforce"
primaryKeyRule.Pattern = "^id$"
config.Rules["primary_key_naming"] = primaryKeyRule
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() returned error: %v", err)
}
if report.Summary.RulesChecked != 1 {
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
}
// All results should be at error level for enforced rules
for _, violation := range report.Violations {
if violation.Level != "error" {
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
}
}
}
func TestGenerateSummary(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
results := []ValidationResult{
{RuleName: "rule1", Passed: true, Level: "error"},
{RuleName: "rule2", Passed: false, Level: "error"},
{RuleName: "rule3", Passed: false, Level: "warning"},
{RuleName: "rule4", Passed: true, Level: "warning"},
}
summary := inspector.generateSummary(results)
if summary.PassedCount != 2 {
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
}
if summary.ErrorCount != 1 {
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
}
if summary.WarningCount != 1 {
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
}
}
func TestHasErrors(t *testing.T) {
tests := []struct {
name string
report *InspectorReport
want bool
}{
{
name: "with errors",
report: &InspectorReport{
Summary: ReportSummary{
ErrorCount: 5,
},
},
want: true,
},
{
name: "without errors",
report: &InspectorReport{
Summary: ReportSummary{
ErrorCount: 0,
WarningCount: 3,
},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.report.HasErrors(); got != tt.want {
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetValidator(t *testing.T) {
tests := []struct {
name string
functionName string
wantExists bool
}{
{"primary_key_naming", "primary_key_naming", true},
{"primary_key_datatype", "primary_key_datatype", true},
{"foreign_key_column_naming", "foreign_key_column_naming", true},
{"table_regexpr", "table_regexpr", true},
{"column_regexpr", "column_regexpr", true},
{"reserved_words", "reserved_words", true},
{"have_primary_key", "have_primary_key", true},
{"orphaned_foreign_key", "orphaned_foreign_key", true},
{"circular_dependency", "circular_dependency", true},
{"unknown_function", "unknown_function", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, exists := getValidator(tt.functionName)
if exists != tt.wantExists {
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
}
})
}
}
func TestCreateResult(t *testing.T) {
result := createResult(
"test_rule",
true,
"Test message",
"schema.table.column",
map[string]interface{}{
"key1": "value1",
"key2": 42,
},
)
if result.RuleName != "test_rule" {
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
}
if !result.Passed {
t.Error("createResult() Passed = false, want true")
}
if result.Message != "Test message" {
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
}
if result.Location != "schema.table.column" {
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
}
if len(result.Context) != 2 {
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
}
}

View File

@@ -0,0 +1,366 @@
package inspector
import (
"bytes"
"encoding/json"
"strings"
"testing"
"time"
)
func createTestReport() *InspectorReport {
return &InspectorReport{
Summary: ReportSummary{
TotalRules: 10,
RulesChecked: 8,
RulesSkipped: 2,
ErrorCount: 3,
WarningCount: 5,
PassedCount: 12,
},
Violations: []ValidationResult{
{
RuleName: "primary_key_naming",
Level: "error",
Message: "Primary key should start with 'id_'",
Location: "public.users.user_id",
Passed: false,
Context: map[string]interface{}{
"schema": "public",
"table": "users",
"column": "user_id",
"pattern": "^id_",
},
},
{
RuleName: "table_name_length",
Level: "warning",
Message: "Table name too long",
Location: "public.very_long_table_name_that_exceeds_limits",
Passed: false,
Context: map[string]interface{}{
"schema": "public",
"table": "very_long_table_name_that_exceeds_limits",
"length": 44,
"max_length": 32,
},
},
},
GeneratedAt: time.Now(),
Database: "testdb",
SourceFormat: "postgresql",
}
}
func TestNewMarkdownFormatter(t *testing.T) {
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
if formatter == nil {
t.Fatal("NewMarkdownFormatter() returned nil")
}
// Buffer is not a terminal, so colors should be disabled
if formatter.UseColors {
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
}
}
func TestNewJSONFormatter(t *testing.T) {
formatter := NewJSONFormatter()
if formatter == nil {
t.Fatal("NewJSONFormatter() returned nil")
}
}
func TestMarkdownFormatter_Format(t *testing.T) {
report := createTestReport()
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
}
// Check that output contains expected sections
if !strings.Contains(output, "# RelSpec Inspector Report") {
t.Error("Markdown output missing header")
}
if !strings.Contains(output, "Database:") {
t.Error("Markdown output missing database field")
}
if !strings.Contains(output, "testdb") {
t.Error("Markdown output missing database name")
}
if !strings.Contains(output, "Summary") {
t.Error("Markdown output missing summary section")
}
if !strings.Contains(output, "Rules Checked: 8") {
t.Error("Markdown output missing rules checked count")
}
if !strings.Contains(output, "Errors: 3") {
t.Error("Markdown output missing error count")
}
if !strings.Contains(output, "Warnings: 5") {
t.Error("Markdown output missing warning count")
}
if !strings.Contains(output, "Violations") {
t.Error("Markdown output missing violations section")
}
if !strings.Contains(output, "primary_key_naming") {
t.Error("Markdown output missing rule name")
}
if !strings.Contains(output, "public.users.user_id") {
t.Error("Markdown output missing location")
}
}
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
report := &InspectorReport{
Summary: ReportSummary{
TotalRules: 10,
RulesChecked: 10,
RulesSkipped: 0,
ErrorCount: 0,
WarningCount: 0,
PassedCount: 50,
},
Violations: []ValidationResult{},
GeneratedAt: time.Now(),
Database: "testdb",
SourceFormat: "postgresql",
}
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
}
if !strings.Contains(output, "No violations found") {
t.Error("Markdown output should indicate no violations")
}
}
func TestJSONFormatter_Format(t *testing.T) {
report := createTestReport()
formatter := NewJSONFormatter()
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
}
// Verify it's valid JSON
var decoded InspectorReport
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
}
// Check key fields
if decoded.Database != "testdb" {
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
}
if decoded.Summary.ErrorCount != 3 {
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
}
if len(decoded.Violations) != 2 {
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
}
}
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
header := formatter.formatHeader("Test Header")
if !strings.Contains(header, "# Test Header") {
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
}
}
func TestMarkdownFormatter_FormatBold(t *testing.T) {
tests := []struct {
name string
useColors bool
text string
wantContains string
}{
{
name: "without colors",
useColors: false,
text: "Bold Text",
wantContains: "**Bold Text**",
},
{
name: "with colors",
useColors: true,
text: "Bold Text",
wantContains: "Bold Text",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: tt.useColors}
result := formatter.formatBold(tt.text)
if !strings.Contains(result, tt.wantContains) {
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
}
})
}
}
func TestMarkdownFormatter_Colorize(t *testing.T) {
tests := []struct {
name string
useColors bool
text string
color string
wantColor bool
}{
{
name: "without colors",
useColors: false,
text: "Test",
color: colorRed,
wantColor: false,
},
{
name: "with colors",
useColors: true,
text: "Test",
color: colorRed,
wantColor: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: tt.useColors}
result := formatter.colorize(tt.text, tt.color)
hasColor := strings.Contains(result, tt.color)
if hasColor != tt.wantColor {
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
}
if !strings.Contains(result, tt.text) {
t.Errorf("colorize() doesn't contain original text %q", tt.text)
}
})
}
}
func TestMarkdownFormatter_FormatContext(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: false}
context := map[string]interface{}{
"schema": "public",
"table": "users",
"column": "id",
"pattern": "^id_",
"max_length": 64,
}
result := formatter.formatContext(context)
// Should not include schema, table, column (they're in location)
if strings.Contains(result, "schema") {
t.Error("formatContext() should skip schema field")
}
if strings.Contains(result, "table=") {
t.Error("formatContext() should skip table field")
}
if strings.Contains(result, "column=") {
t.Error("formatContext() should skip column field")
}
// Should include other fields
if !strings.Contains(result, "pattern") {
t.Error("formatContext() should include pattern field")
}
if !strings.Contains(result, "max_length") {
t.Error("formatContext() should include max_length field")
}
}
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: false}
violation := ValidationResult{
RuleName: "test_rule",
Level: "error",
Message: "Test violation message",
Location: "public.users.id",
Passed: false,
Context: map[string]interface{}{
"pattern": "^id_",
},
}
result := formatter.formatViolation(violation, colorRed)
if !strings.Contains(result, "test_rule") {
t.Error("formatViolation() should include rule name")
}
if !strings.Contains(result, "Test violation message") {
t.Error("formatViolation() should include message")
}
if !strings.Contains(result, "public.users.id") {
t.Error("formatViolation() should include location")
}
if !strings.Contains(result, "Location:") {
t.Error("formatViolation() should include Location label")
}
if !strings.Contains(result, "Message:") {
t.Error("formatViolation() should include Message label")
}
}
func TestReportFormatConstants(t *testing.T) {
// Test that color constants are defined
if colorReset == "" {
t.Error("colorReset is not defined")
}
if colorRed == "" {
t.Error("colorRed is not defined")
}
if colorYellow == "" {
t.Error("colorYellow is not defined")
}
if colorGreen == "" {
t.Error("colorGreen is not defined")
}
if colorBold == "" {
t.Error("colorBold is not defined")
}
}

249
pkg/inspector/rules_test.go Normal file
View File

@@ -0,0 +1,249 @@
package inspector
import (
"os"
"path/filepath"
"testing"
)
func TestGetDefaultConfig(t *testing.T) {
config := GetDefaultConfig()
if config == nil {
t.Fatal("GetDefaultConfig() returned nil")
}
if config.Version != "1.0" {
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
}
if len(config.Rules) == 0 {
t.Error("GetDefaultConfig() returned no rules")
}
// Check that all expected rules are present
expectedRules := []string{
"primary_key_naming",
"primary_key_datatype",
"primary_key_auto_increment",
"foreign_key_column_naming",
"foreign_key_constraint_naming",
"foreign_key_index",
"table_naming_case",
"column_naming_case",
"table_name_length",
"column_name_length",
"reserved_keywords",
"missing_primary_key",
"orphaned_foreign_key",
"circular_dependency",
}
for _, ruleName := range expectedRules {
if _, exists := config.Rules[ruleName]; !exists {
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
}
}
}
func TestLoadConfig_NonExistentFile(t *testing.T) {
// Try to load a non-existent file
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
if err != nil {
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
}
// Should return default config
if config == nil {
t.Fatal("LoadConfig() returned nil config for non-existent file")
}
if len(config.Rules) == 0 {
t.Error("LoadConfig() returned config with no rules")
}
}
func TestLoadConfig_ValidFile(t *testing.T) {
// Create a temporary config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test-config.yaml")
configContent := `version: "1.0"
rules:
primary_key_naming:
enabled: "enforce"
function: "primary_key_naming"
pattern: "^pk_"
message: "Primary keys must start with pk_"
table_name_length:
enabled: "warn"
function: "table_name_length"
max_length: 50
message: "Table name too long"
`
err := os.WriteFile(configPath, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
config, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() returned error: %v", err)
}
if config.Version != "1.0" {
t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
}
if len(config.Rules) != 2 {
t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
}
// Check primary_key_naming rule
pkRule, exists := config.Rules["primary_key_naming"]
if !exists {
t.Fatal("LoadConfig() missing primary_key_naming rule")
}
if pkRule.Enabled != "enforce" {
t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
}
if pkRule.Pattern != "^pk_" {
t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
}
// Check table_name_length rule
lengthRule, exists := config.Rules["table_name_length"]
if !exists {
t.Fatal("LoadConfig() missing table_name_length rule")
}
if lengthRule.MaxLength != 50 {
t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
}
}
func TestLoadConfig_InvalidYAML(t *testing.T) {
// Create a temporary invalid config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
invalidContent := `invalid: yaml: content: {[}]`
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
_, err = LoadConfig(configPath)
if err == nil {
t.Error("LoadConfig() with invalid YAML did not return error")
}
}
func TestRuleIsEnabled(t *testing.T) {
tests := []struct {
name string
rule Rule
want bool
}{
{
name: "enforce is enabled",
rule: Rule{Enabled: "enforce"},
want: true,
},
{
name: "warn is enabled",
rule: Rule{Enabled: "warn"},
want: true,
},
{
name: "off is not enabled",
rule: Rule{Enabled: "off"},
want: false,
},
{
name: "empty is not enabled",
rule: Rule{Enabled: ""},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.rule.IsEnabled(); got != tt.want {
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
}
})
}
}
func TestRuleIsEnforced(t *testing.T) {
tests := []struct {
name string
rule Rule
want bool
}{
{
name: "enforce is enforced",
rule: Rule{Enabled: "enforce"},
want: true,
},
{
name: "warn is not enforced",
rule: Rule{Enabled: "warn"},
want: false,
},
{
name: "off is not enforced",
rule: Rule{Enabled: "off"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.rule.IsEnforced(); got != tt.want {
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
}
})
}
}
func TestDefaultConfigRuleSettings(t *testing.T) {
config := GetDefaultConfig()
// Test specific rule settings
pkNamingRule := config.Rules["primary_key_naming"]
if pkNamingRule.Function != "primary_key_naming" {
t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
}
if pkNamingRule.Pattern != "^id_" {
t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
}
// Test datatype rule
pkDatatypeRule := config.Rules["primary_key_datatype"]
if len(pkDatatypeRule.AllowedTypes) == 0 {
t.Error("primary_key_datatype has no allowed types")
}
// Test length rule
tableLengthRule := config.Rules["table_name_length"]
if tableLengthRule.MaxLength != 64 {
t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
}
// Test reserved keywords rule
reservedRule := config.Rules["reserved_keywords"]
if !reservedRule.CheckTables {
t.Error("reserved_keywords.CheckTables should be true")
}
if !reservedRule.CheckColumns {
t.Error("reserved_keywords.CheckColumns should be true")
}
}

View File

@@ -0,0 +1,837 @@
package inspector
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// Helper function to create test database
func createTestDatabase() *models.Database {
return &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigserial",
IsPrimaryKey: true,
AutoIncrement: true,
},
"username": {
Name: "username",
Type: "varchar(50)",
NotNull: true,
IsPrimaryKey: false,
},
"rid_organization": {
Name: "rid_organization",
Type: "bigint",
NotNull: true,
IsPrimaryKey: false,
},
},
Constraints: map[string]*models.Constraint{
"fk_users_organization": {
Name: "fk_users_organization",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_organization"},
ReferencedTable: "organizations",
ReferencedSchema: "public",
ReferencedColumns: []string{"id"},
},
},
Indexes: map[string]*models.Index{
"idx_rid_organization": {
Name: "idx_rid_organization",
Columns: []string{"rid_organization"},
},
},
},
{
Name: "organizations",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigserial",
IsPrimaryKey: true,
AutoIncrement: true,
},
"name": {
Name: "name",
Type: "varchar(100)",
NotNull: true,
IsPrimaryKey: false,
},
},
},
},
},
},
}
}
func TestValidatePrimaryKeyNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern id",
rule: Rule{
Pattern: "^id$",
Message: "Primary key should be 'id'",
},
wantLen: 2,
wantPass: true,
},
{
name: "non-matching pattern id_",
rule: Rule{
Pattern: "^id_",
Message: "Primary key should start with 'id_'",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidatePrimaryKeyDatatype(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "allowed type bigserial",
rule: Rule{
AllowedTypes: []string{"bigserial", "bigint", "int"},
Message: "Primary key should use integer types",
},
wantLen: 2,
wantPass: true,
},
{
name: "disallowed type",
rule: Rule{
AllowedTypes: []string{"uuid"},
Message: "Primary key should use UUID",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
}{
{
name: "require auto increment",
rule: Rule{
RequireAutoIncrement: true,
Message: "Primary key should have auto-increment",
},
wantLen: 0, // No violations - all PKs have auto-increment
},
{
name: "disallow auto increment",
rule: Rule{
RequireAutoIncrement: false,
Message: "Primary key should not have auto-increment",
},
wantLen: 2, // 2 violations - both PKs have auto-increment
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateForeignKeyColumnNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern rid_",
rule: Rule{
Pattern: "^rid_",
Message: "Foreign key columns should start with 'rid_'",
},
wantLen: 1,
wantPass: true,
},
{
name: "non-matching pattern fk_",
rule: Rule{
Pattern: "^fk_",
Message: "Foreign key columns should start with 'fk_'",
},
wantLen: 1,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern fk_",
rule: Rule{
Pattern: "^fk_",
Message: "Foreign key constraints should start with 'fk_'",
},
wantLen: 1,
wantPass: true,
},
{
name: "non-matching pattern FK_",
rule: Rule{
Pattern: "^FK_",
Message: "Foreign key constraints should start with 'FK_'",
},
wantLen: 1,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateForeignKeyIndex(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "require index with index present",
rule: Rule{
RequireIndex: true,
Message: "Foreign key columns should have indexes",
},
wantLen: 1,
wantPass: true,
},
{
name: "no requirement",
rule: Rule{
RequireIndex: false,
Message: "Foreign key index check disabled",
},
wantLen: 0,
wantPass: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyIndex(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateTableNamingCase(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "lowercase snake_case pattern",
rule: Rule{
Pattern: "^[a-z][a-z0-9_]*$",
Case: "lowercase",
Message: "Table names should be lowercase snake_case",
},
wantLen: 2,
wantPass: true,
},
{
name: "uppercase pattern",
rule: Rule{
Pattern: "^[A-Z][A-Z0-9_]*$",
Case: "uppercase",
Message: "Table names should be uppercase",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateTableNamingCase(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateColumnNamingCase(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "lowercase snake_case pattern",
rule: Rule{
Pattern: "^[a-z][a-z0-9_]*$",
Case: "lowercase",
Message: "Column names should be lowercase snake_case",
},
wantLen: 5, // 5 total columns across both tables
wantPass: true,
},
{
name: "camelCase pattern",
rule: Rule{
Pattern: "^[a-z][a-zA-Z0-9]*$",
Case: "camelCase",
Message: "Column names should be camelCase",
},
wantLen: 5,
wantPass: false, // rid_organization has underscore
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateColumnNamingCase(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateTableNameLength(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "max length 64",
rule: Rule{
MaxLength: 64,
Message: "Table name too long",
},
wantLen: 2,
wantPass: true,
},
{
name: "max length 5",
rule: Rule{
MaxLength: 5,
Message: "Table name too long",
},
wantLen: 2,
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateTableNameLength(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateColumnNameLength(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "max length 64",
rule: Rule{
MaxLength: 64,
Message: "Column name too long",
},
wantLen: 5,
wantPass: true,
},
{
name: "max length 5",
rule: Rule{
MaxLength: 5,
Message: "Column name too long",
},
wantLen: 5,
wantPass: false, // Some columns exceed 5 chars
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateColumnNameLength(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateReservedKeywords(t *testing.T) {
// Create a database with reserved keywords
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "user", // "user" is a reserved keyword
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
"select": { // "select" is a reserved keyword
Name: "select",
Type: "varchar(50)",
},
},
},
},
},
},
}
tests := []struct {
name string
rule Rule
wantLen int
checkPasses bool
}{
{
name: "check tables only",
rule: Rule{
CheckTables: true,
CheckColumns: false,
Message: "Reserved keyword used",
},
wantLen: 1, // "user" table
checkPasses: false,
},
{
name: "check columns only",
rule: Rule{
CheckTables: false,
CheckColumns: true,
Message: "Reserved keyword used",
},
wantLen: 2, // "id", "select" columns (id passes, select fails)
checkPasses: false,
},
{
name: "check both",
rule: Rule{
CheckTables: true,
CheckColumns: true,
Message: "Reserved keyword used",
},
wantLen: 3, // "user" table + "id", "select" columns
checkPasses: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateReservedKeywords(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateMissingPrimaryKey(t *testing.T) {
// Create database with and without primary keys
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "with_pk",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
},
},
{
Name: "without_pk",
Columns: map[string]*models.Column{
"name": {
Name: "name",
Type: "varchar(50)",
},
},
},
},
},
},
}
rule := Rule{
Message: "Table missing primary key",
}
results := validateMissingPrimaryKey(db, rule, "test_rule")
if len(results) != 2 {
t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
}
// First result should pass (with_pk has PK)
if results[0].Passed != true {
t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
}
// Second result should fail (without_pk missing PK)
if results[1].Passed != false {
t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
}
}
func TestValidateOrphanedForeignKey(t *testing.T) {
// Create database with orphaned FK
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
},
Constraints: map[string]*models.Constraint{
"fk_nonexistent": {
Name: "fk_nonexistent",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_organization"},
ReferencedTable: "nonexistent_table",
ReferencedSchema: "public",
},
},
},
},
},
},
}
rule := Rule{
Message: "Foreign key references non-existent table",
}
results := validateOrphanedForeignKey(db, rule, "test_rule")
if len(results) != 1 {
t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
}
if results[0].Passed != false {
t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
}
}
func TestValidateCircularDependency(t *testing.T) {
// Create database with circular dependency
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "table_a",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
},
Constraints: map[string]*models.Constraint{
"fk_to_b": {
Name: "fk_to_b",
Type: models.ForeignKeyConstraint,
ReferencedTable: "table_b",
ReferencedSchema: "public",
},
},
},
{
Name: "table_b",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
},
Constraints: map[string]*models.Constraint{
"fk_to_a": {
Name: "fk_to_a",
Type: models.ForeignKeyConstraint,
ReferencedTable: "table_a",
ReferencedSchema: "public",
},
},
},
},
},
},
}
rule := Rule{
Message: "Circular dependency detected",
}
results := validateCircularDependency(db, rule, "test_rule")
// Should detect circular dependency in both tables
if len(results) == 0 {
t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
}
for _, result := range results {
if result.Passed {
t.Error("validateCircularDependency() passed=true, want false for circular dependency")
}
}
}
func TestNormalizeDataType(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"varchar(50)", "varchar"},
{"decimal(10,2)", "decimal"},
{"int", "int"},
{"BIGINT", "bigint"},
{"VARCHAR(255)", "varchar"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := normalizeDataType(tt.input)
if result != tt.expected {
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestContains(t *testing.T) {
tests := []struct {
name string
slice []string
value string
expected bool
}{
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
{"empty slice", []string{}, "foo", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.slice, tt.value)
if result != tt.expected {
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
}
})
}
}
func TestHasCycle(t *testing.T) {
tests := []struct {
name string
graph map[string][]string
node string
expected bool
}{
{
name: "simple cycle",
graph: map[string][]string{
"A": {"B"},
"B": {"C"},
"C": {"A"},
},
node: "A",
expected: true,
},
{
name: "no cycle",
graph: map[string][]string{
"A": {"B"},
"B": {"C"},
"C": {},
},
node: "A",
expected: false,
},
{
name: "self cycle",
graph: map[string][]string{
"A": {"A"},
},
node: "A",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
visited := make(map[string]bool)
recStack := make(map[string]bool)
result := hasCycle(tt.node, tt.graph, visited, recStack)
if result != tt.expected {
t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
}
})
}
}
func TestFormatLocation(t *testing.T) {
tests := []struct {
schema string
table string
column string
expected string
}{
{"public", "users", "id", "public.users.id"},
{"public", "users", "", "public.users"},
{"public", "", "", "public"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := formatLocation(tt.schema, tt.table, tt.column)
if result != tt.expected {
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
tt.schema, tt.table, tt.column, result, tt.expected)
}
})
}
}

View File

@@ -15,6 +15,8 @@ type MergeResult struct {
SchemasAdded int
TablesAdded int
ColumnsAdded int
ConstraintsAdded int
IndexesAdded int
RelationsAdded int
DomainsAdded int
EnumsAdded int
@@ -171,6 +173,7 @@ func (r *MergeResult) mergeConstraints(table *models.Table, srcTable *models.Tab
// Constraint doesn't exist, add it
newConst := cloneConstraint(srcConst)
table.Constraints[constName] = newConst
r.ConstraintsAdded++
}
}
}
@@ -193,6 +196,7 @@ func (r *MergeResult) mergeIndexes(table *models.Table, srcTable *models.Table)
// Index doesn't exist, add it
newIdx := cloneIndex(srcIdx)
table.Indexes[idxName] = newIdx
r.IndexesAdded++
}
}
}
@@ -598,6 +602,8 @@ func GetMergeSummary(result *MergeResult) string {
fmt.Sprintf("Schemas added: %d", result.SchemasAdded),
fmt.Sprintf("Tables added: %d", result.TablesAdded),
fmt.Sprintf("Columns added: %d", result.ColumnsAdded),
fmt.Sprintf("Constraints added: %d", result.ConstraintsAdded),
fmt.Sprintf("Indexes added: %d", result.IndexesAdded),
fmt.Sprintf("Views added: %d", result.ViewsAdded),
fmt.Sprintf("Sequences added: %d", result.SequencesAdded),
fmt.Sprintf("Enums added: %d", result.EnumsAdded),
@@ -606,6 +612,7 @@ func GetMergeSummary(result *MergeResult) string {
}
totalAdded := result.SchemasAdded + result.TablesAdded + result.ColumnsAdded +
result.ConstraintsAdded + result.IndexesAdded +
result.ViewsAdded + result.SequencesAdded + result.EnumsAdded +
result.RelationsAdded + result.DomainsAdded

617
pkg/merge/merge_test.go Normal file
View File

@@ -0,0 +1,617 @@
package merge
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
func TestMergeDatabases_NilInputs(t *testing.T) {
result := MergeDatabases(nil, nil, nil)
if result == nil {
t.Fatal("Expected non-nil result")
}
if result.SchemasAdded != 0 {
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
}
}
func TestMergeDatabases_NewSchema(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{Name: "auth"},
},
}
result := MergeDatabases(target, source, nil)
if result.SchemasAdded != 1 {
t.Errorf("Expected 1 schema added, got %d", result.SchemasAdded)
}
if len(target.Schemas) != 2 {
t.Errorf("Expected 2 schemas in target, got %d", len(target.Schemas))
}
}
func TestMergeDatabases_ExistingSchema(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{Name: "public"},
},
}
result := MergeDatabases(target, source, nil)
if result.SchemasAdded != 0 {
t.Errorf("Expected 0 schemas added, got %d", result.SchemasAdded)
}
if len(target.Schemas) != 1 {
t.Errorf("Expected 1 schema in target, got %d", len(target.Schemas))
}
}
func TestMergeTables_NewTable(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "posts",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.TablesAdded != 1 {
t.Errorf("Expected 1 table added, got %d", result.TablesAdded)
}
if len(target.Schemas[0].Tables) != 2 {
t.Errorf("Expected 2 tables in target schema, got %d", len(target.Schemas[0].Tables))
}
}
func TestMergeColumns_NewColumn(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "int"},
},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"email": {Name: "email", Type: "varchar"},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ColumnsAdded != 1 {
t.Errorf("Expected 1 column added, got %d", result.ColumnsAdded)
}
if len(target.Schemas[0].Tables[0].Columns) != 2 {
t.Errorf("Expected 2 columns in target table, got %d", len(target.Schemas[0].Tables[0].Columns))
}
}
func TestMergeConstraints_NewConstraint(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ConstraintsAdded != 1 {
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
}
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
}
}
func TestMergeConstraints_NilConstraintsMap(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: nil, // Nil map
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ConstraintsAdded != 1 {
t.Errorf("Expected 1 constraint added, got %d", result.ConstraintsAdded)
}
if target.Schemas[0].Tables[0].Constraints == nil {
t.Error("Expected constraints map to be initialized")
}
if len(target.Schemas[0].Tables[0].Constraints) != 1 {
t.Errorf("Expected 1 constraint in target table, got %d", len(target.Schemas[0].Tables[0].Constraints))
}
}
func TestMergeIndexes_NewIndex(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
}
}
func TestMergeIndexes_NilIndexesMap(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: nil, // Nil map
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
if target.Schemas[0].Tables[0].Indexes == nil {
t.Error("Expected indexes map to be initialized")
}
if len(target.Schemas[0].Tables[0].Indexes) != 1 {
t.Errorf("Expected 1 index in target table, got %d", len(target.Schemas[0].Tables[0].Indexes))
}
}
func TestMergeOptions_SkipTableNames(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "migrations",
Schema: "public",
Columns: map[string]*models.Column{},
},
},
},
},
}
opts := &MergeOptions{
SkipTableNames: map[string]bool{
"migrations": true,
},
}
result := MergeDatabases(target, source, opts)
if result.TablesAdded != 0 {
t.Errorf("Expected 0 tables added (skipped), got %d", result.TablesAdded)
}
if len(target.Schemas[0].Tables) != 1 {
t.Errorf("Expected 1 table in target schema, got %d", len(target.Schemas[0].Tables))
}
}
func TestMergeViews_NewView(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Views: []*models.View{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Views: []*models.View{
{
Name: "user_summary",
Schema: "public",
Definition: "SELECT * FROM users",
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.ViewsAdded != 1 {
t.Errorf("Expected 1 view added, got %d", result.ViewsAdded)
}
if len(target.Schemas[0].Views) != 1 {
t.Errorf("Expected 1 view in target schema, got %d", len(target.Schemas[0].Views))
}
}
func TestMergeEnums_NewEnum(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Enums: []*models.Enum{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Enums: []*models.Enum{
{
Name: "user_role",
Schema: "public",
Values: []string{"admin", "user"},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.EnumsAdded != 1 {
t.Errorf("Expected 1 enum added, got %d", result.EnumsAdded)
}
if len(target.Schemas[0].Enums) != 1 {
t.Errorf("Expected 1 enum in target schema, got %d", len(target.Schemas[0].Enums))
}
}
func TestMergeDomains_NewDomain(t *testing.T) {
target := &models.Database{
Domains: []*models.Domain{},
}
source := &models.Database{
Domains: []*models.Domain{
{
Name: "auth",
Description: "Authentication domain",
},
},
}
result := MergeDatabases(target, source, nil)
if result.DomainsAdded != 1 {
t.Errorf("Expected 1 domain added, got %d", result.DomainsAdded)
}
if len(target.Domains) != 1 {
t.Errorf("Expected 1 domain in target, got %d", len(target.Domains))
}
}
func TestMergeRelations_NewRelation(t *testing.T) {
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Relations: []*models.Relationship{},
},
},
}
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Relations: []*models.Relationship{
{
Name: "fk_posts_user",
Type: models.OneToMany,
FromTable: "posts",
FromColumns: []string{"user_id"},
ToTable: "users",
ToColumns: []string{"id"},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
if result.RelationsAdded != 1 {
t.Errorf("Expected 1 relation added, got %d", result.RelationsAdded)
}
if len(target.Schemas[0].Relations) != 1 {
t.Errorf("Expected 1 relation in target schema, got %d", len(target.Schemas[0].Relations))
}
}
func TestGetMergeSummary(t *testing.T) {
result := &MergeResult{
SchemasAdded: 1,
TablesAdded: 2,
ColumnsAdded: 5,
ConstraintsAdded: 3,
IndexesAdded: 2,
ViewsAdded: 1,
}
summary := GetMergeSummary(result)
if summary == "" {
t.Error("Expected non-empty summary")
}
if len(summary) < 50 {
t.Errorf("Summary seems too short: %s", summary)
}
}
func TestGetMergeSummary_Nil(t *testing.T) {
summary := GetMergeSummary(nil)
if summary == "" {
t.Error("Expected non-empty summary for nil result")
}
}
func TestComplexMerge(t *testing.T) {
// Target with existing structure
target := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "int"},
},
Constraints: map[string]*models.Constraint{},
Indexes: map[string]*models.Index{},
},
},
},
},
}
// Source with new columns, constraints, and indexes
source := &models.Database{
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Schema: "public",
Columns: map[string]*models.Column{
"email": {Name: "email", Type: "varchar"},
"guid": {Name: "guid", Type: "uuid"},
},
Constraints: map[string]*models.Constraint{
"ukey_users_email": {
Type: models.UniqueConstraint,
Columns: []string{"email"},
Name: "ukey_users_email",
},
"ukey_users_guid": {
Type: models.UniqueConstraint,
Columns: []string{"guid"},
Name: "ukey_users_guid",
},
},
Indexes: map[string]*models.Index{
"idx_users_email": {
Name: "idx_users_email",
Columns: []string{"email"},
},
},
},
},
},
},
}
result := MergeDatabases(target, source, nil)
// Verify counts
if result.ColumnsAdded != 2 {
t.Errorf("Expected 2 columns added, got %d", result.ColumnsAdded)
}
if result.ConstraintsAdded != 2 {
t.Errorf("Expected 2 constraints added, got %d", result.ConstraintsAdded)
}
if result.IndexesAdded != 1 {
t.Errorf("Expected 1 index added, got %d", result.IndexesAdded)
}
// Verify target has merged data
table := target.Schemas[0].Tables[0]
if len(table.Columns) != 3 {
t.Errorf("Expected 3 columns in merged table, got %d", len(table.Columns))
}
if len(table.Constraints) != 2 {
t.Errorf("Expected 2 constraints in merged table, got %d", len(table.Constraints))
}
if len(table.Indexes) != 1 {
t.Errorf("Expected 1 index in merged table, got %d", len(table.Indexes))
}
// Verify specific constraint
if _, exists := table.Constraints["ukey_users_guid"]; !exists {
t.Error("Expected ukey_users_guid constraint to exist")
}
}

339
pkg/pgsql/datatypes_test.go Normal file
View File

@@ -0,0 +1,339 @@
package pgsql
import (
"testing"
)
func TestValidSQLType(t *testing.T) {
tests := []struct {
name string
sqltype string
want bool
}{
// PostgreSQL types
{"Valid PGSQL bigint", "bigint", true},
{"Valid PGSQL integer", "integer", true},
{"Valid PGSQL text", "text", true},
{"Valid PGSQL boolean", "boolean", true},
{"Valid PGSQL double precision", "double precision", true},
{"Valid PGSQL bytea", "bytea", true},
{"Valid PGSQL uuid", "uuid", true},
{"Valid PGSQL jsonb", "jsonb", true},
{"Valid PGSQL json", "json", true},
{"Valid PGSQL timestamp", "timestamp", true},
{"Valid PGSQL date", "date", true},
{"Valid PGSQL time", "time", true},
{"Valid PGSQL citext", "citext", true},
// Standard types
{"Valid std double", "double", true},
{"Valid std blob", "blob", true},
// Case insensitive
{"Case insensitive BIGINT", "BIGINT", true},
{"Case insensitive TeXt", "TeXt", true},
{"Case insensitive BoOlEaN", "BoOlEaN", true},
// Invalid types
{"Invalid type", "invalidtype", false},
{"Invalid type varchar", "varchar", false},
{"Empty string", "", false},
{"Random string", "foobar", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ValidSQLType(tt.sqltype)
if got != tt.want {
t.Errorf("ValidSQLType(%q) = %v, want %v", tt.sqltype, got, tt.want)
}
})
}
}
func TestGetSQLType(t *testing.T) {
tests := []struct {
name string
anytype string
want string
}{
// Go types to PostgreSQL types
{"Go bool to boolean", "bool", "boolean"},
{"Go int64 to bigint", "int64", "bigint"},
{"Go int to integer", "int", "integer"},
{"Go string to text", "string", "text"},
{"Go float64 to double precision", "float64", "double precision"},
{"Go float32 to real", "float32", "real"},
{"Go []byte to bytea", "[]byte", "bytea"},
// SQL types remain SQL types
{"SQL bigint", "bigint", "bigint"},
{"SQL integer", "integer", "integer"},
{"SQL text", "text", "text"},
{"SQL boolean", "boolean", "boolean"},
{"SQL uuid", "uuid", "uuid"},
{"SQL jsonb", "jsonb", "jsonb"},
// Case insensitive Go types
{"Case insensitive BOOL", "BOOL", "boolean"},
{"Case insensitive InT64", "InT64", "bigint"},
{"Case insensitive STRING", "STRING", "text"},
// Case insensitive SQL types
{"Case insensitive BIGINT", "BIGINT", "bigint"},
{"Case insensitive TEXT", "TEXT", "text"},
// Custom types
{"Custom sqluuid", "sqluuid", "uuid"},
{"Custom sqljsonb", "sqljsonb", "jsonb"},
{"Custom sqlint64", "sqlint64", "bigint"},
// Unknown types default to text
{"Unknown type varchar", "varchar", "text"},
{"Unknown type foobar", "foobar", "text"},
{"Empty string", "", "text"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetSQLType(tt.anytype)
if got != tt.want {
t.Errorf("GetSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
}
})
}
}
func TestConvertSQLType(t *testing.T) {
tests := []struct {
name string
anytype string
want string
}{
// Go types to PostgreSQL types
{"Go bool to boolean", "bool", "boolean"},
{"Go int64 to bigint", "int64", "bigint"},
{"Go int to integer", "int", "integer"},
{"Go string to text", "string", "text"},
{"Go float64 to double precision", "float64", "double precision"},
{"Go float32 to real", "float32", "real"},
{"Go []byte to bytea", "[]byte", "bytea"},
// SQL types remain SQL types
{"SQL bigint", "bigint", "bigint"},
{"SQL integer", "integer", "integer"},
{"SQL text", "text", "text"},
{"SQL boolean", "boolean", "boolean"},
// Case insensitive
{"Case insensitive BOOL", "BOOL", "boolean"},
{"Case insensitive InT64", "InT64", "bigint"},
// Unknown types remain unchanged (difference from GetSQLType)
{"Unknown type varchar", "varchar", "varchar"},
{"Unknown type foobar", "foobar", "foobar"},
{"Empty string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ConvertSQLType(tt.anytype)
if got != tt.want {
t.Errorf("ConvertSQLType(%q) = %q, want %q", tt.anytype, got, tt.want)
}
})
}
}
func TestIsGoType(t *testing.T) {
tests := []struct {
name string
typeName string
want bool
}{
// Go basic types
{"Go bool", "bool", true},
{"Go int64", "int64", true},
{"Go int", "int", true},
{"Go int32", "int32", true},
{"Go int16", "int16", true},
{"Go int8", "int8", true},
{"Go uint", "uint", true},
{"Go uint64", "uint64", true},
{"Go uint32", "uint32", true},
{"Go uint16", "uint16", true},
{"Go uint8", "uint8", true},
{"Go float64", "float64", true},
{"Go float32", "float32", true},
{"Go string", "string", true},
{"Go []byte", "[]byte", true},
// Go custom types
{"Go complex64", "complex64", true},
{"Go complex128", "complex128", true},
{"Go uintptr", "uintptr", true},
{"Go Pointer", "Pointer", true},
// Custom SQL types
{"Custom sqluuid", "sqluuid", true},
{"Custom sqljsonb", "sqljsonb", true},
{"Custom sqlint64", "sqlint64", true},
{"Custom customdate", "customdate", true},
{"Custom customtime", "customtime", true},
// Case insensitive
{"Case insensitive BOOL", "BOOL", true},
{"Case insensitive InT64", "InT64", true},
{"Case insensitive STRING", "STRING", true},
// SQL types (not Go types)
{"SQL bigint", "bigint", false},
{"SQL integer", "integer", false},
{"SQL text", "text", false},
{"SQL boolean", "boolean", false},
// Invalid types
{"Invalid type", "invalidtype", false},
{"Empty string", "", false},
{"Random string", "foobar", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsGoType(tt.typeName)
if got != tt.want {
t.Errorf("IsGoType(%q) = %v, want %v", tt.typeName, got, tt.want)
}
})
}
}
func TestGetStdTypeFromGo(t *testing.T) {
tests := []struct {
name string
typeName string
want string
}{
// Go types to standard SQL types
{"Go bool to boolean", "bool", "boolean"},
{"Go int64 to bigint", "int64", "bigint"},
{"Go int to integer", "int", "integer"},
{"Go string to text", "string", "text"},
{"Go float64 to double", "float64", "double"},
{"Go float32 to double", "float32", "double"},
{"Go []byte to blob", "[]byte", "blob"},
{"Go int32 to integer", "int32", "integer"},
{"Go int16 to smallint", "int16", "smallint"},
// Custom types
{"Custom sqluuid to uuid", "sqluuid", "uuid"},
{"Custom sqljsonb to jsonb", "sqljsonb", "jsonb"},
{"Custom sqlint64 to bigint", "sqlint64", "bigint"},
{"Custom customdate to date", "customdate", "date"},
// Case insensitive
{"Case insensitive BOOL", "BOOL", "boolean"},
{"Case insensitive InT64", "InT64", "bigint"},
{"Case insensitive STRING", "STRING", "text"},
// Non-Go types remain unchanged
{"SQL bigint unchanged", "bigint", "bigint"},
{"SQL integer unchanged", "integer", "integer"},
{"Invalid type unchanged", "invalidtype", "invalidtype"},
{"Empty string unchanged", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetStdTypeFromGo(tt.typeName)
if got != tt.want {
t.Errorf("GetStdTypeFromGo(%q) = %q, want %q", tt.typeName, got, tt.want)
}
})
}
}
func TestGoToStdTypesMap(t *testing.T) {
// Test that the map contains expected entries
expectedMappings := map[string]string{
"bool": "boolean",
"int64": "bigint",
"int": "integer",
"string": "text",
"float64": "double",
"[]byte": "blob",
}
for goType, expectedStd := range expectedMappings {
if stdType, ok := GoToStdTypes[goType]; !ok {
t.Errorf("GoToStdTypes missing entry for %q", goType)
} else if stdType != expectedStd {
t.Errorf("GoToStdTypes[%q] = %q, want %q", goType, stdType, expectedStd)
}
}
// Test that the map is not empty
if len(GoToStdTypes) == 0 {
t.Error("GoToStdTypes map is empty")
}
}
func TestGoToPGSQLTypesMap(t *testing.T) {
// Test that the map contains expected entries
expectedMappings := map[string]string{
"bool": "boolean",
"int64": "bigint",
"int": "integer",
"string": "text",
"float64": "double precision",
"float32": "real",
"[]byte": "bytea",
}
for goType, expectedPG := range expectedMappings {
if pgType, ok := GoToPGSQLTypes[goType]; !ok {
t.Errorf("GoToPGSQLTypes missing entry for %q", goType)
} else if pgType != expectedPG {
t.Errorf("GoToPGSQLTypes[%q] = %q, want %q", goType, pgType, expectedPG)
}
}
// Test that the map is not empty
if len(GoToPGSQLTypes) == 0 {
t.Error("GoToPGSQLTypes map is empty")
}
}
func TestTypeConversionConsistency(t *testing.T) {
// Test that GetSQLType and ConvertSQLType are consistent for known types
knownGoTypes := []string{"bool", "int64", "int", "string", "float64", "[]byte"}
for _, goType := range knownGoTypes {
getSQLResult := GetSQLType(goType)
convertResult := ConvertSQLType(goType)
if getSQLResult != convertResult {
t.Errorf("Inconsistent results for %q: GetSQLType=%q, ConvertSQLType=%q",
goType, getSQLResult, convertResult)
}
}
}
func TestGetSQLTypeVsConvertSQLTypeDifference(t *testing.T) {
// Test that GetSQLType returns "text" for unknown types
// while ConvertSQLType returns the input unchanged
unknownTypes := []string{"varchar", "char", "customtype", "unknowntype"}
for _, unknown := range unknownTypes {
getSQLResult := GetSQLType(unknown)
convertResult := ConvertSQLType(unknown)
if getSQLResult != "text" {
t.Errorf("GetSQLType(%q) = %q, want %q", unknown, getSQLResult, "text")
}
if convertResult != unknown {
t.Errorf("ConvertSQLType(%q) = %q, want %q", unknown, convertResult, unknown)
}
}
}

136
pkg/pgsql/keywords_test.go Normal file
View File

@@ -0,0 +1,136 @@
package pgsql
import (
"testing"
)
func TestGetPostgresKeywords(t *testing.T) {
keywords := GetPostgresKeywords()
// Test that keywords are returned
if len(keywords) == 0 {
t.Fatal("Expected non-empty list of keywords")
}
// Test that we get all keywords from the map
expectedCount := len(postgresKeywords)
if len(keywords) != expectedCount {
t.Errorf("Expected %d keywords, got %d", expectedCount, len(keywords))
}
// Test that all returned keywords exist in the map
for _, keyword := range keywords {
if !postgresKeywords[keyword] {
t.Errorf("Keyword %q not found in postgresKeywords map", keyword)
}
}
// Test that no duplicate keywords are returned
seen := make(map[string]bool)
for _, keyword := range keywords {
if seen[keyword] {
t.Errorf("Duplicate keyword found: %q", keyword)
}
seen[keyword] = true
}
}
func TestPostgresKeywordsMap(t *testing.T) {
tests := []struct {
name string
keyword string
want bool
}{
{"SELECT keyword", "select", true},
{"FROM keyword", "from", true},
{"WHERE keyword", "where", true},
{"TABLE keyword", "table", true},
{"PRIMARY keyword", "primary", true},
{"FOREIGN keyword", "foreign", true},
{"CREATE keyword", "create", true},
{"DROP keyword", "drop", true},
{"ALTER keyword", "alter", true},
{"INDEX keyword", "index", true},
{"NOT keyword", "not", true},
{"NULL keyword", "null", true},
{"TRUE keyword", "true", true},
{"FALSE keyword", "false", true},
{"Non-keyword lowercase", "notakeyword", false},
{"Non-keyword uppercase", "NOTAKEYWORD", false},
{"Empty string", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := postgresKeywords[tt.keyword]
if got != tt.want {
t.Errorf("postgresKeywords[%q] = %v, want %v", tt.keyword, got, tt.want)
}
})
}
}
func TestPostgresKeywordsMapContent(t *testing.T) {
// Test that the map contains expected common keywords
commonKeywords := []string{
"select", "insert", "update", "delete", "create", "drop", "alter",
"table", "index", "view", "schema", "function", "procedure",
"primary", "foreign", "key", "constraint", "unique", "check",
"null", "not", "and", "or", "like", "in", "between",
"join", "inner", "left", "right", "cross", "full", "outer",
"where", "having", "group", "order", "limit", "offset",
"union", "intersect", "except",
"begin", "commit", "rollback", "transaction",
}
for _, keyword := range commonKeywords {
if !postgresKeywords[keyword] {
t.Errorf("Expected common keyword %q to be in postgresKeywords map", keyword)
}
}
}
func TestPostgresKeywordsMapSize(t *testing.T) {
// PostgreSQL has a substantial list of reserved keywords
// This test ensures the map has a reasonable number of entries
minExpectedKeywords := 200 // PostgreSQL 13+ has 400+ reserved words
if len(postgresKeywords) < minExpectedKeywords {
t.Errorf("Expected at least %d keywords, got %d. The map may be incomplete.",
minExpectedKeywords, len(postgresKeywords))
}
}
func TestGetPostgresKeywordsConsistency(t *testing.T) {
// Test that calling GetPostgresKeywords multiple times returns consistent results
keywords1 := GetPostgresKeywords()
keywords2 := GetPostgresKeywords()
if len(keywords1) != len(keywords2) {
t.Errorf("Inconsistent results: first call returned %d keywords, second call returned %d",
len(keywords1), len(keywords2))
}
// Create a map from both results to compare
map1 := make(map[string]bool)
map2 := make(map[string]bool)
for _, k := range keywords1 {
map1[k] = true
}
for _, k := range keywords2 {
map2[k] = true
}
// Check that both contain the same keywords
for k := range map1 {
if !map2[k] {
t.Errorf("Keyword %q present in first call but not in second", k)
}
}
for k := range map2 {
if !map1[k] {
t.Errorf("Keyword %q present in second call but not in first", k)
}
}
}

View File

@@ -0,0 +1,490 @@
package reflectutil
import (
"reflect"
"testing"
)
type testStruct struct {
Name string
Age int
Active bool
Nested *nestedStruct
Private string
}
type nestedStruct struct {
Value string
Count int
}
func TestDeref(t *testing.T) {
tests := []struct {
name string
input interface{}
wantValid bool
wantKind reflect.Kind
}{
{
name: "non-pointer int",
input: 42,
wantValid: true,
wantKind: reflect.Int,
},
{
name: "single pointer",
input: ptrInt(42),
wantValid: true,
wantKind: reflect.Int,
},
{
name: "double pointer",
input: ptrPtr(ptrInt(42)),
wantValid: true,
wantKind: reflect.Int,
},
{
name: "nil pointer",
input: (*int)(nil),
wantValid: false,
wantKind: reflect.Ptr,
},
{
name: "string",
input: "test",
wantValid: true,
wantKind: reflect.String,
},
{
name: "struct",
input: testStruct{Name: "test"},
wantValid: true,
wantKind: reflect.Struct,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := reflect.ValueOf(tt.input)
got, valid := Deref(v)
if valid != tt.wantValid {
t.Errorf("Deref() valid = %v, want %v", valid, tt.wantValid)
}
if got.Kind() != tt.wantKind {
t.Errorf("Deref() kind = %v, want %v", got.Kind(), tt.wantKind)
}
})
}
}
func TestDerefInterface(t *testing.T) {
i := 42
pi := &i
ppi := &pi
tests := []struct {
name string
input interface{}
wantKind reflect.Kind
}{
{"int", 42, reflect.Int},
{"pointer to int", &i, reflect.Int},
{"double pointer to int", ppi, reflect.Int},
{"string", "test", reflect.String},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := DerefInterface(tt.input)
if got.Kind() != tt.wantKind {
t.Errorf("DerefInterface() kind = %v, want %v", got.Kind(), tt.wantKind)
}
})
}
}
func TestGetFieldValue(t *testing.T) {
ts := testStruct{
Name: "John",
Age: 30,
Active: true,
Nested: &nestedStruct{Value: "nested", Count: 5},
}
tests := []struct {
name string
item interface{}
field string
want interface{}
}{
{"struct field Name", ts, "Name", "John"},
{"struct field Age", ts, "Age", 30},
{"struct field Active", ts, "Active", true},
{"struct non-existent field", ts, "NonExistent", nil},
{"pointer to struct", &ts, "Name", "John"},
{"map string key", map[string]string{"key": "value"}, "key", "value"},
{"map int key", map[string]int{"count": 42}, "count", 42},
{"map non-existent key", map[string]string{"key": "value"}, "missing", nil},
{"nil pointer", (*testStruct)(nil), "Name", nil},
{"non-struct non-map", 42, "field", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetFieldValue(tt.item, tt.field)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetFieldValue() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsSliceOrArray(t *testing.T) {
arr := [3]int{1, 2, 3}
tests := []struct {
name string
input interface{}
want bool
}{
{"slice", []int{1, 2, 3}, true},
{"array", arr, true},
{"pointer to slice", &[]int{1, 2, 3}, true},
{"string", "test", false},
{"int", 42, false},
{"map", map[string]int{}, false},
{"nil slice", ([]int)(nil), true}, // nil slice is still Kind==Slice
{"nil pointer", (*[]int)(nil), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsSliceOrArray(tt.input)
if got != tt.want {
t.Errorf("IsSliceOrArray() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsMap(t *testing.T) {
tests := []struct {
name string
input interface{}
want bool
}{
{"map[string]int", map[string]int{"a": 1}, true},
{"map[int]string", map[int]string{1: "a"}, true},
{"pointer to map", &map[string]int{"a": 1}, true},
{"slice", []int{1, 2, 3}, false},
{"string", "test", false},
{"int", 42, false},
{"nil map", (map[string]int)(nil), true}, // nil map is still Kind==Map
{"nil pointer", (*map[string]int)(nil), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsMap(tt.input)
if got != tt.want {
t.Errorf("IsMap() = %v, want %v", got, tt.want)
}
})
}
}
func TestSliceLen(t *testing.T) {
arr := [3]int{1, 2, 3}
tests := []struct {
name string
input interface{}
want int
}{
{"slice length 3", []int{1, 2, 3}, 3},
{"empty slice", []int{}, 0},
{"array length 3", arr, 3},
{"pointer to slice", &[]int{1, 2, 3}, 3},
{"not a slice", "test", 0},
{"int", 42, 0},
{"nil slice", ([]int)(nil), 0},
{"nil pointer", (*[]int)(nil), 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SliceLen(tt.input)
if got != tt.want {
t.Errorf("SliceLen() = %v, want %v", got, tt.want)
}
})
}
}
func TestMapLen(t *testing.T) {
tests := []struct {
name string
input interface{}
want int
}{
{"map length 2", map[string]int{"a": 1, "b": 2}, 2},
{"empty map", map[string]int{}, 0},
{"pointer to map", &map[string]int{"a": 1}, 1},
{"not a map", []int{1, 2, 3}, 0},
{"string", "test", 0},
{"nil map", (map[string]int)(nil), 0},
{"nil pointer", (*map[string]int)(nil), 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MapLen(tt.input)
if got != tt.want {
t.Errorf("MapLen() = %v, want %v", got, tt.want)
}
})
}
}
func TestSliceToInterfaces(t *testing.T) {
tests := []struct {
name string
input interface{}
want []interface{}
}{
{"int slice", []int{1, 2, 3}, []interface{}{1, 2, 3}},
{"string slice", []string{"a", "b"}, []interface{}{"a", "b"}},
{"empty slice", []int{}, []interface{}{}},
{"pointer to slice", &[]int{1, 2}, []interface{}{1, 2}},
{"not a slice", "test", []interface{}{}},
{"nil slice", ([]int)(nil), []interface{}{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SliceToInterfaces(tt.input)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SliceToInterfaces() = %v, want %v", got, tt.want)
}
})
}
}
func TestMapKeys(t *testing.T) {
tests := []struct {
name string
input interface{}
want []interface{}
}{
{"map with keys", map[string]int{"a": 1, "b": 2}, []interface{}{"a", "b"}},
{"empty map", map[string]int{}, []interface{}{}},
{"not a map", []int{1, 2, 3}, []interface{}{}},
{"nil map", (map[string]int)(nil), []interface{}{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MapKeys(tt.input)
if len(got) != len(tt.want) {
t.Errorf("MapKeys() length = %v, want %v", len(got), len(tt.want))
}
// For maps, order is not guaranteed, so just check length
})
}
}
func TestMapValues(t *testing.T) {
tests := []struct {
name string
input interface{}
want int // length of values
}{
{"map with values", map[string]int{"a": 1, "b": 2}, 2},
{"empty map", map[string]int{}, 0},
{"not a map", []int{1, 2, 3}, 0},
{"nil map", (map[string]int)(nil), 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MapValues(tt.input)
if len(got) != tt.want {
t.Errorf("MapValues() length = %v, want %v", len(got), tt.want)
}
})
}
}
func TestMapGet(t *testing.T) {
m := map[string]int{"a": 1, "b": 2}
tests := []struct {
name string
input interface{}
key interface{}
want interface{}
}{
{"existing key", m, "a", 1},
{"existing key b", m, "b", 2},
{"non-existing key", m, "c", nil},
{"pointer to map", &m, "a", 1},
{"not a map", []int{1, 2}, 0, nil},
{"nil map", (map[string]int)(nil), "a", nil},
{"nil pointer", (*map[string]int)(nil), "a", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MapGet(tt.input, tt.key)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("MapGet() = %v, want %v", got, tt.want)
}
})
}
}
func TestSliceIndex(t *testing.T) {
s := []int{10, 20, 30}
tests := []struct {
name string
slice interface{}
index int
want interface{}
}{
{"index 0", s, 0, 10},
{"index 1", s, 1, 20},
{"index 2", s, 2, 30},
{"negative index", s, -1, nil},
{"out of bounds", s, 5, nil},
{"pointer to slice", &s, 1, 20},
{"not a slice", "test", 0, nil},
{"nil slice", ([]int)(nil), 0, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SliceIndex(tt.slice, tt.index)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SliceIndex() = %v, want %v", got, tt.want)
}
})
}
}
func TestCompareValues(t *testing.T) {
tests := []struct {
name string
a interface{}
b interface{}
want int
}{
{"both nil", nil, nil, 0},
{"a nil", nil, 5, -1},
{"b nil", 5, nil, 1},
{"equal strings", "abc", "abc", 0},
{"a less than b strings", "abc", "xyz", -1},
{"a greater than b strings", "xyz", "abc", 1},
{"equal ints", 5, 5, 0},
{"a less than b ints", 3, 7, -1},
{"a greater than b ints", 10, 5, 1},
{"equal floats", 3.14, 3.14, 0},
{"a less than b floats", 2.5, 5.5, -1},
{"a greater than b floats", 10.5, 5.5, 1},
{"equal uints", uint(5), uint(5), 0},
{"different types", "abc", 123, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := CompareValues(tt.a, tt.b)
if got != tt.want {
t.Errorf("CompareValues(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
}
})
}
}
func TestGetNestedValue(t *testing.T) {
nested := map[string]interface{}{
"level1": map[string]interface{}{
"level2": map[string]interface{}{
"value": "deep",
},
},
}
ts := testStruct{
Name: "John",
Nested: &nestedStruct{
Value: "nested value",
Count: 42,
},
}
tests := []struct {
name string
input interface{}
path string
want interface{}
}{
{"empty path", nested, "", nested},
{"single level map", nested, "level1", nested["level1"]},
{"nested map", nested, "level1.level2", map[string]interface{}{"value": "deep"}},
{"deep nested map", nested, "level1.level2.value", "deep"},
{"struct field", ts, "Name", "John"},
{"nested struct field", ts, "Nested", ts.Nested},
{"non-existent path", nested, "missing.path", nil},
{"nil input", nil, "path", nil},
{"partial missing path", nested, "level1.missing", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetNestedValue(tt.input, tt.path)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetNestedValue() = %v, want %v", got, tt.want)
}
})
}
}
func TestDeepEqual(t *testing.T) {
tests := []struct {
name string
a interface{}
b interface{}
want bool
}{
{"equal ints", 42, 42, true},
{"different ints", 42, 43, false},
{"equal strings", "test", "test", true},
{"different strings", "test", "other", false},
{"equal slices", []int{1, 2, 3}, []int{1, 2, 3}, true},
{"different slices", []int{1, 2, 3}, []int{1, 2, 4}, false},
{"equal maps", map[string]int{"a": 1}, map[string]int{"a": 1}, true},
{"different maps", map[string]int{"a": 1}, map[string]int{"a": 2}, false},
{"both nil", nil, nil, true},
{"one nil", nil, 42, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := DeepEqual(tt.a, tt.b)
if got != tt.want {
t.Errorf("DeepEqual(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want)
}
})
}
}
// Helper functions
func ptrInt(i int) *int {
return &i
}
func ptrPtr(p *int) **int {
return &p
}