From 5d9770b430771e481859a08b2d3251b396882cfa Mon Sep 17 00:00:00 2001 From: Hein Date: Sat, 31 Jan 2026 22:30:00 +0200 Subject: [PATCH] =?UTF-8?q?test(pgsql,=20reflectutil):=20=E2=9C=A8=20add?= =?UTF-8?q?=20comprehensive=20test=20coverage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduce tests for PostgreSQL data types and keywords. * Implement tests for reflect utility functions. * Ensure consistency and correctness of type conversions and keyword mappings. * Validate behavior for various edge cases and input types. --- pkg/commontypes/commontypes_test.go | 714 ++++++++++++++++++++++++ pkg/diff/diff_test.go | 558 +++++++++++++++++++ pkg/diff/formatters_test.go | 440 +++++++++++++++ pkg/inspector/inspector_test.go | 238 ++++++++ pkg/inspector/report_test.go | 366 ++++++++++++ pkg/inspector/rules_test.go | 249 +++++++++ pkg/inspector/validators_test.go | 837 ++++++++++++++++++++++++++++ pkg/pgsql/datatypes_test.go | 339 +++++++++++ pkg/pgsql/keywords_test.go | 136 +++++ pkg/reflectutil/helpers_test.go | 490 ++++++++++++++++ 10 files changed, 4367 insertions(+) create mode 100644 pkg/commontypes/commontypes_test.go create mode 100644 pkg/diff/diff_test.go create mode 100644 pkg/diff/formatters_test.go create mode 100644 pkg/inspector/inspector_test.go create mode 100644 pkg/inspector/report_test.go create mode 100644 pkg/inspector/rules_test.go create mode 100644 pkg/inspector/validators_test.go create mode 100644 pkg/pgsql/datatypes_test.go create mode 100644 pkg/pgsql/keywords_test.go create mode 100644 pkg/reflectutil/helpers_test.go diff --git a/pkg/commontypes/commontypes_test.go b/pkg/commontypes/commontypes_test.go new file mode 100644 index 0000000..5a0cf7d --- /dev/null +++ b/pkg/commontypes/commontypes_test.go @@ -0,0 +1,714 @@ +package commontypes + +import ( + "testing" +) + +func TestExtractBaseType(t *testing.T) { + tests := []struct { + name string + sqlType string + want string + }{ + {"varchar with length", "varchar(100)", "varchar"}, + {"VARCHAR uppercase with length", "VARCHAR(255)", "varchar"}, + {"numeric with precision", "numeric(10,2)", "numeric"}, + {"NUMERIC uppercase", "NUMERIC(18,4)", "numeric"}, + {"decimal with precision", "decimal(15,3)", "decimal"}, + {"char with length", "char(50)", "char"}, + {"simple integer", "integer", "integer"}, + {"simple text", "text", "text"}, + {"bigint", "bigint", "bigint"}, + {"With spaces", " varchar(100) ", "varchar"}, + {"No parentheses", "boolean", "boolean"}, + {"Empty string", "", ""}, + {"Mixed case", "VarChar(100)", "varchar"}, + {"timestamp with time zone", "timestamp(6) with time zone", "timestamp"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractBaseType(tt.sqlType) + if got != tt.want { + t.Errorf("ExtractBaseType(%q) = %q, want %q", tt.sqlType, got, tt.want) + } + }) + } +} + +func TestNormalizeType(t *testing.T) { + // NormalizeType is an alias for ExtractBaseType, test that they behave the same + testCases := []string{ + "varchar(100)", + "numeric(10,2)", + "integer", + "text", + " VARCHAR(255) ", + } + + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + extracted := ExtractBaseType(tc) + normalized := NormalizeType(tc) + if extracted != normalized { + t.Errorf("ExtractBaseType(%q) = %q, but NormalizeType(%q) = %q", + tc, extracted, tc, normalized) + } + }) + } +} + +func TestSQLToGo(t *testing.T) { + tests := []struct { + name string + sqlType string + nullable bool + want string + }{ + // Integer types (nullable) + {"integer nullable", "integer", true, "int32"}, + {"bigint nullable", "bigint", true, "int64"}, + {"smallint nullable", "smallint", true, "int16"}, + {"serial nullable", "serial", true, "int32"}, + + // Integer types (not nullable) + {"integer not nullable", "integer", false, "*int32"}, + {"bigint not nullable", "bigint", false, "*int64"}, + {"smallint not nullable", "smallint", false, "*int16"}, + + // String types (nullable) + {"text nullable", "text", true, "string"}, + {"varchar nullable", "varchar", true, "string"}, + {"varchar with length nullable", "varchar(100)", true, "string"}, + + // String types (not nullable) + {"text not nullable", "text", false, "*string"}, + {"varchar not nullable", "varchar", false, "*string"}, + + // Boolean + {"boolean nullable", "boolean", true, "bool"}, + {"boolean not nullable", "boolean", false, "*bool"}, + + // Float types + {"real nullable", "real", true, "float32"}, + {"double precision nullable", "double precision", true, "float64"}, + {"real not nullable", "real", false, "*float32"}, + {"double precision not nullable", "double precision", false, "*float64"}, + + // Date/Time types + {"timestamp nullable", "timestamp", true, "time.Time"}, + {"date nullable", "date", true, "time.Time"}, + {"timestamp not nullable", "timestamp", false, "*time.Time"}, + + // Binary + {"bytea nullable", "bytea", true, "[]byte"}, + {"bytea not nullable", "bytea", false, "[]byte"}, // Slices don't get pointer + + // UUID + {"uuid nullable", "uuid", true, "string"}, + {"uuid not nullable", "uuid", false, "*string"}, + + // JSON + {"json nullable", "json", true, "string"}, + {"jsonb nullable", "jsonb", true, "string"}, + + // Array + {"array nullable", "array", true, "[]string"}, + {"array not nullable", "array", false, "[]string"}, // Slices don't get pointer + + // Unknown types + {"unknown type nullable", "unknowntype", true, "interface{}"}, + {"unknown type not nullable", "unknowntype", false, "interface{}"}, // Interface doesn't get pointer + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SQLToGo(tt.sqlType, tt.nullable) + if got != tt.want { + t.Errorf("SQLToGo(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want) + } + }) + } +} + +func TestSQLToTypeScript(t *testing.T) { + tests := []struct { + name string + sqlType string + nullable bool + want string + }{ + // Integer types + {"integer nullable", "integer", true, "number"}, + {"integer not nullable", "integer", false, "number | null"}, + {"bigint nullable", "bigint", true, "number"}, + {"bigint not nullable", "bigint", false, "number | null"}, + + // String types + {"text nullable", "text", true, "string"}, + {"text not nullable", "text", false, "string | null"}, + {"varchar nullable", "varchar", true, "string"}, + {"varchar(100) nullable", "varchar(100)", true, "string"}, + + // Boolean + {"boolean nullable", "boolean", true, "boolean"}, + {"boolean not nullable", "boolean", false, "boolean | null"}, + + // Float types + {"real nullable", "real", true, "number"}, + {"double precision nullable", "double precision", true, "number"}, + + // Date/Time types + {"timestamp nullable", "timestamp", true, "Date"}, + {"date nullable", "date", true, "Date"}, + {"timestamp not nullable", "timestamp", false, "Date | null"}, + + // Binary + {"bytea nullable", "bytea", true, "Buffer"}, + {"bytea not nullable", "bytea", false, "Buffer | null"}, + + // JSON + {"json nullable", "json", true, "any"}, + {"jsonb nullable", "jsonb", true, "any"}, + + // UUID + {"uuid nullable", "uuid", true, "string"}, + + // Unknown types + {"unknown type nullable", "unknowntype", true, "any"}, + {"unknown type not nullable", "unknowntype", false, "any | null"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SQLToTypeScript(tt.sqlType, tt.nullable) + if got != tt.want { + t.Errorf("SQLToTypeScript(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want) + } + }) + } +} + +func TestSQLToPython(t *testing.T) { + tests := []struct { + name string + sqlType string + want string + }{ + // Integer types + {"integer", "integer", "int"}, + {"bigint", "bigint", "int"}, + {"smallint", "smallint", "int"}, + + // String types + {"text", "text", "str"}, + {"varchar", "varchar", "str"}, + {"varchar(100)", "varchar(100)", "str"}, + + // Boolean + {"boolean", "boolean", "bool"}, + + // Float types + {"real", "real", "float"}, + {"double precision", "double precision", "float"}, + {"numeric", "numeric", "Decimal"}, + {"decimal", "decimal", "Decimal"}, + + // Date/Time types + {"timestamp", "timestamp", "datetime"}, + {"date", "date", "date"}, + {"time", "time", "time"}, + + // Binary + {"bytea", "bytea", "bytes"}, + + // JSON + {"json", "json", "dict"}, + {"jsonb", "jsonb", "dict"}, + + // UUID + {"uuid", "uuid", "UUID"}, + + // Array + {"array", "array", "list"}, + + // Unknown types + {"unknown type", "unknowntype", "Any"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SQLToPython(tt.sqlType) + if got != tt.want { + t.Errorf("SQLToPython(%q) = %q, want %q", tt.sqlType, got, tt.want) + } + }) + } +} + +func TestSQLToCSharp(t *testing.T) { + tests := []struct { + name string + sqlType string + nullable bool + want string + }{ + // Integer types (nullable) + {"integer nullable", "integer", true, "int"}, + {"bigint nullable", "bigint", true, "long"}, + {"smallint nullable", "smallint", true, "short"}, + + // Integer types (not nullable - value types get ?) + {"integer not nullable", "integer", false, "int?"}, + {"bigint not nullable", "bigint", false, "long?"}, + {"smallint not nullable", "smallint", false, "short?"}, + + // String types (reference types, no ? needed) + {"text nullable", "text", true, "string"}, + {"text not nullable", "text", false, "string"}, + {"varchar nullable", "varchar", true, "string"}, + {"varchar(100) nullable", "varchar(100)", true, "string"}, + + // Boolean + {"boolean nullable", "boolean", true, "bool"}, + {"boolean not nullable", "boolean", false, "bool?"}, + + // Float types + {"real nullable", "real", true, "float"}, + {"double precision nullable", "double precision", true, "double"}, + {"decimal nullable", "decimal", true, "decimal"}, + {"real not nullable", "real", false, "float?"}, + {"double precision not nullable", "double precision", false, "double?"}, + {"decimal not nullable", "decimal", false, "decimal?"}, + + // Date/Time types + {"timestamp nullable", "timestamp", true, "DateTime"}, + {"date nullable", "date", true, "DateTime"}, + {"timestamptz nullable", "timestamptz", true, "DateTimeOffset"}, + {"timestamp not nullable", "timestamp", false, "DateTime?"}, + {"timestamptz not nullable", "timestamptz", false, "DateTimeOffset?"}, + + // Binary (array type, no ?) + {"bytea nullable", "bytea", true, "byte[]"}, + {"bytea not nullable", "bytea", false, "byte[]"}, + + // UUID + {"uuid nullable", "uuid", true, "Guid"}, + {"uuid not nullable", "uuid", false, "Guid?"}, + + // JSON + {"json nullable", "json", true, "string"}, + + // Unknown types (object is reference type) + {"unknown type nullable", "unknowntype", true, "object"}, + {"unknown type not nullable", "unknowntype", false, "object"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SQLToCSharp(tt.sqlType, tt.nullable) + if got != tt.want { + t.Errorf("SQLToCSharp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want) + } + }) + } +} + +func TestNeedsTimeImport(t *testing.T) { + tests := []struct { + name string + goType string + want bool + }{ + {"time.Time type", "time.Time", true}, + {"pointer to time.Time", "*time.Time", true}, + {"int32 type", "int32", false}, + {"string type", "string", false}, + {"bool type", "bool", false}, + {"[]byte type", "[]byte", false}, + {"interface{}", "interface{}", false}, + {"empty string", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NeedsTimeImport(tt.goType) + if got != tt.want { + t.Errorf("NeedsTimeImport(%q) = %v, want %v", tt.goType, got, tt.want) + } + }) + } +} + +func TestGoTypeMap(t *testing.T) { + // Test that the map contains expected entries + expectedMappings := map[string]string{ + "integer": "int32", + "bigint": "int64", + "text": "string", + "boolean": "bool", + "double precision": "float64", + "bytea": "[]byte", + "timestamp": "time.Time", + "uuid": "string", + "json": "string", + } + + for sqlType, expectedGoType := range expectedMappings { + if goType, ok := GoTypeMap[sqlType]; !ok { + t.Errorf("GoTypeMap missing entry for %q", sqlType) + } else if goType != expectedGoType { + t.Errorf("GoTypeMap[%q] = %q, want %q", sqlType, goType, expectedGoType) + } + } + + if len(GoTypeMap) == 0 { + t.Error("GoTypeMap is empty") + } +} + +func TestTypeScriptTypeMap(t *testing.T) { + expectedMappings := map[string]string{ + "integer": "number", + "bigint": "number", + "text": "string", + "boolean": "boolean", + "double precision": "number", + "bytea": "Buffer", + "timestamp": "Date", + "uuid": "string", + "json": "any", + } + + for sqlType, expectedTSType := range expectedMappings { + if tsType, ok := TypeScriptTypeMap[sqlType]; !ok { + t.Errorf("TypeScriptTypeMap missing entry for %q", sqlType) + } else if tsType != expectedTSType { + t.Errorf("TypeScriptTypeMap[%q] = %q, want %q", sqlType, tsType, expectedTSType) + } + } + + if len(TypeScriptTypeMap) == 0 { + t.Error("TypeScriptTypeMap is empty") + } +} + +func TestPythonTypeMap(t *testing.T) { + expectedMappings := map[string]string{ + "integer": "int", + "bigint": "int", + "text": "str", + "boolean": "bool", + "real": "float", + "numeric": "Decimal", + "bytea": "bytes", + "date": "date", + "uuid": "UUID", + "json": "dict", + } + + for sqlType, expectedPyType := range expectedMappings { + if pyType, ok := PythonTypeMap[sqlType]; !ok { + t.Errorf("PythonTypeMap missing entry for %q", sqlType) + } else if pyType != expectedPyType { + t.Errorf("PythonTypeMap[%q] = %q, want %q", sqlType, pyType, expectedPyType) + } + } + + if len(PythonTypeMap) == 0 { + t.Error("PythonTypeMap is empty") + } +} + +func TestCSharpTypeMap(t *testing.T) { + expectedMappings := map[string]string{ + "integer": "int", + "bigint": "long", + "smallint": "short", + "text": "string", + "boolean": "bool", + "double precision": "double", + "decimal": "decimal", + "bytea": "byte[]", + "timestamp": "DateTime", + "uuid": "Guid", + "json": "string", + } + + for sqlType, expectedCSType := range expectedMappings { + if csType, ok := CSharpTypeMap[sqlType]; !ok { + t.Errorf("CSharpTypeMap missing entry for %q", sqlType) + } else if csType != expectedCSType { + t.Errorf("CSharpTypeMap[%q] = %q, want %q", sqlType, csType, expectedCSType) + } + } + + if len(CSharpTypeMap) == 0 { + t.Error("CSharpTypeMap is empty") + } +} + +func TestSQLToJava(t *testing.T) { + tests := []struct { + name string + sqlType string + nullable bool + want string + }{ + // Integer types + {"integer nullable", "integer", true, "Integer"}, + {"integer not nullable", "integer", false, "Integer"}, + {"bigint nullable", "bigint", true, "Long"}, + {"smallint nullable", "smallint", true, "Short"}, + + // String types + {"text nullable", "text", true, "String"}, + {"varchar nullable", "varchar", true, "String"}, + {"varchar(100) nullable", "varchar(100)", true, "String"}, + + // Boolean + {"boolean nullable", "boolean", true, "Boolean"}, + + // Float types + {"real nullable", "real", true, "Float"}, + {"double precision nullable", "double precision", true, "Double"}, + {"numeric nullable", "numeric", true, "BigDecimal"}, + + // Date/Time types + {"timestamp nullable", "timestamp", true, "Timestamp"}, + {"date nullable", "date", true, "Date"}, + {"time nullable", "time", true, "Time"}, + + // Binary + {"bytea nullable", "bytea", true, "byte[]"}, + + // UUID + {"uuid nullable", "uuid", true, "UUID"}, + + // JSON + {"json nullable", "json", true, "String"}, + + // Unknown types + {"unknown type nullable", "unknowntype", true, "Object"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SQLToJava(tt.sqlType, tt.nullable) + if got != tt.want { + t.Errorf("SQLToJava(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want) + } + }) + } +} + +func TestSQLToPhp(t *testing.T) { + tests := []struct { + name string + sqlType string + nullable bool + want string + }{ + // Integer types (nullable) + {"integer nullable", "integer", true, "int"}, + {"bigint nullable", "bigint", true, "int"}, + {"smallint nullable", "smallint", true, "int"}, + + // Integer types (not nullable) + {"integer not nullable", "integer", false, "?int"}, + {"bigint not nullable", "bigint", false, "?int"}, + + // String types + {"text nullable", "text", true, "string"}, + {"text not nullable", "text", false, "?string"}, + {"varchar nullable", "varchar", true, "string"}, + {"varchar(100) nullable", "varchar(100)", true, "string"}, + + // Boolean + {"boolean nullable", "boolean", true, "bool"}, + {"boolean not nullable", "boolean", false, "?bool"}, + + // Float types + {"real nullable", "real", true, "float"}, + {"double precision nullable", "double precision", true, "float"}, + {"real not nullable", "real", false, "?float"}, + + // Date/Time types + {"timestamp nullable", "timestamp", true, "\\DateTime"}, + {"date nullable", "date", true, "\\DateTime"}, + {"timestamp not nullable", "timestamp", false, "?\\DateTime"}, + + // Binary + {"bytea nullable", "bytea", true, "string"}, + {"bytea not nullable", "bytea", false, "?string"}, + + // JSON + {"json nullable", "json", true, "array"}, + {"json not nullable", "json", false, "?array"}, + + // UUID + {"uuid nullable", "uuid", true, "string"}, + + // Unknown types + {"unknown type nullable", "unknowntype", true, "mixed"}, + {"unknown type not nullable", "unknowntype", false, "mixed"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SQLToPhp(tt.sqlType, tt.nullable) + if got != tt.want { + t.Errorf("SQLToPhp(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want) + } + }) + } +} + +func TestSQLToRust(t *testing.T) { + tests := []struct { + name string + sqlType string + nullable bool + want string + }{ + // Integer types (nullable) + {"integer nullable", "integer", true, "i32"}, + {"bigint nullable", "bigint", true, "i64"}, + {"smallint nullable", "smallint", true, "i16"}, + + // Integer types (not nullable) + {"integer not nullable", "integer", false, "Option"}, + {"bigint not nullable", "bigint", false, "Option"}, + {"smallint not nullable", "smallint", false, "Option"}, + + // String types + {"text nullable", "text", true, "String"}, + {"text not nullable", "text", false, "Option"}, + {"varchar nullable", "varchar", true, "String"}, + {"varchar(100) nullable", "varchar(100)", true, "String"}, + + // Boolean + {"boolean nullable", "boolean", true, "bool"}, + {"boolean not nullable", "boolean", false, "Option"}, + + // Float types + {"real nullable", "real", true, "f32"}, + {"double precision nullable", "double precision", true, "f64"}, + {"real not nullable", "real", false, "Option"}, + {"double precision not nullable", "double precision", false, "Option"}, + + // Date/Time types + {"timestamp nullable", "timestamp", true, "NaiveDateTime"}, + {"timestamptz nullable", "timestamptz", true, "DateTime"}, + {"date nullable", "date", true, "NaiveDate"}, + {"time nullable", "time", true, "NaiveTime"}, + {"timestamp not nullable", "timestamp", false, "Option"}, + + // Binary + {"bytea nullable", "bytea", true, "Vec"}, + {"bytea not nullable", "bytea", false, "Option>"}, + + // JSON + {"json nullable", "json", true, "serde_json::Value"}, + {"json not nullable", "json", false, "Option"}, + + // UUID + {"uuid nullable", "uuid", true, "String"}, + + // Unknown types + {"unknown type nullable", "unknowntype", true, "String"}, + {"unknown type not nullable", "unknowntype", false, "Option"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SQLToRust(tt.sqlType, tt.nullable) + if got != tt.want { + t.Errorf("SQLToRust(%q, %v) = %q, want %q", tt.sqlType, tt.nullable, got, tt.want) + } + }) + } +} + +func TestJavaTypeMap(t *testing.T) { + expectedMappings := map[string]string{ + "integer": "Integer", + "bigint": "Long", + "smallint": "Short", + "text": "String", + "boolean": "Boolean", + "double precision": "Double", + "numeric": "BigDecimal", + "bytea": "byte[]", + "timestamp": "Timestamp", + "uuid": "UUID", + "date": "Date", + } + + for sqlType, expectedJavaType := range expectedMappings { + if javaType, ok := JavaTypeMap[sqlType]; !ok { + t.Errorf("JavaTypeMap missing entry for %q", sqlType) + } else if javaType != expectedJavaType { + t.Errorf("JavaTypeMap[%q] = %q, want %q", sqlType, javaType, expectedJavaType) + } + } + + if len(JavaTypeMap) == 0 { + t.Error("JavaTypeMap is empty") + } +} + +func TestPHPTypeMap(t *testing.T) { + expectedMappings := map[string]string{ + "integer": "int", + "bigint": "int", + "text": "string", + "boolean": "bool", + "double precision": "float", + "bytea": "string", + "timestamp": "\\DateTime", + "uuid": "string", + "json": "array", + } + + for sqlType, expectedPHPType := range expectedMappings { + if phpType, ok := PHPTypeMap[sqlType]; !ok { + t.Errorf("PHPTypeMap missing entry for %q", sqlType) + } else if phpType != expectedPHPType { + t.Errorf("PHPTypeMap[%q] = %q, want %q", sqlType, phpType, expectedPHPType) + } + } + + if len(PHPTypeMap) == 0 { + t.Error("PHPTypeMap is empty") + } +} + +func TestRustTypeMap(t *testing.T) { + expectedMappings := map[string]string{ + "integer": "i32", + "bigint": "i64", + "smallint": "i16", + "text": "String", + "boolean": "bool", + "double precision": "f64", + "real": "f32", + "bytea": "Vec", + "timestamp": "NaiveDateTime", + "timestamptz": "DateTime", + "date": "NaiveDate", + "json": "serde_json::Value", + } + + for sqlType, expectedRustType := range expectedMappings { + if rustType, ok := RustTypeMap[sqlType]; !ok { + t.Errorf("RustTypeMap missing entry for %q", sqlType) + } else if rustType != expectedRustType { + t.Errorf("RustTypeMap[%q] = %q, want %q", sqlType, rustType, expectedRustType) + } + } + + if len(RustTypeMap) == 0 { + t.Error("RustTypeMap is empty") + } +} diff --git a/pkg/diff/diff_test.go b/pkg/diff/diff_test.go new file mode 100644 index 0000000..1d54ab7 --- /dev/null +++ b/pkg/diff/diff_test.go @@ -0,0 +1,558 @@ +package diff + +import ( + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" +) + +func TestCompareDatabases(t *testing.T) { + tests := []struct { + name string + source *models.Database + target *models.Database + want func(*DiffResult) bool + }{ + { + name: "identical databases", + source: &models.Database{ + Name: "source", + Schemas: []*models.Schema{}, + }, + target: &models.Database{ + Name: "target", + Schemas: []*models.Schema{}, + }, + want: func(r *DiffResult) bool { + return r.Source == "source" && r.Target == "target" && + len(r.Schemas.Missing) == 0 && len(r.Schemas.Extra) == 0 + }, + }, + { + name: "different schemas", + source: &models.Database{ + Name: "source", + Schemas: []*models.Schema{ + {Name: "public"}, + }, + }, + target: &models.Database{ + Name: "target", + Schemas: []*models.Schema{}, + }, + want: func(r *DiffResult) bool { + return len(r.Schemas.Missing) == 1 && r.Schemas.Missing[0].Name == "public" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CompareDatabases(tt.source, tt.target) + if !tt.want(got) { + t.Errorf("CompareDatabases() result doesn't match expectations") + } + }) + } +} + +func TestCompareColumns(t *testing.T) { + tests := []struct { + name string + source map[string]*models.Column + target map[string]*models.Column + want func(*ColumnDiff) bool + }{ + { + name: "identical columns", + source: map[string]*models.Column{}, + target: map[string]*models.Column{}, + want: func(d *ColumnDiff) bool { + return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0 + }, + }, + { + name: "missing column", + source: map[string]*models.Column{ + "id": {Name: "id", Type: "integer"}, + }, + target: map[string]*models.Column{}, + want: func(d *ColumnDiff) bool { + return len(d.Missing) == 1 && d.Missing[0].Name == "id" + }, + }, + { + name: "extra column", + source: map[string]*models.Column{}, + target: map[string]*models.Column{ + "id": {Name: "id", Type: "integer"}, + }, + want: func(d *ColumnDiff) bool { + return len(d.Extra) == 1 && d.Extra[0].Name == "id" + }, + }, + { + name: "modified column type", + source: map[string]*models.Column{ + "id": {Name: "id", Type: "integer"}, + }, + target: map[string]*models.Column{ + "id": {Name: "id", Type: "bigint"}, + }, + want: func(d *ColumnDiff) bool { + return len(d.Modified) == 1 && d.Modified[0].Name == "id" && + d.Modified[0].Changes["type"] != nil + }, + }, + { + name: "modified column nullable", + source: map[string]*models.Column{ + "name": {Name: "name", Type: "text", NotNull: true}, + }, + target: map[string]*models.Column{ + "name": {Name: "name", Type: "text", NotNull: false}, + }, + want: func(d *ColumnDiff) bool { + return len(d.Modified) == 1 && d.Modified[0].Changes["not_null"] != nil + }, + }, + { + name: "modified column length", + source: map[string]*models.Column{ + "name": {Name: "name", Type: "varchar", Length: 100}, + }, + target: map[string]*models.Column{ + "name": {Name: "name", Type: "varchar", Length: 255}, + }, + want: func(d *ColumnDiff) bool { + return len(d.Modified) == 1 && d.Modified[0].Changes["length"] != nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := compareColumns(tt.source, tt.target) + if !tt.want(got) { + t.Errorf("compareColumns() result doesn't match expectations") + } + }) + } +} + +func TestCompareColumnDetails(t *testing.T) { + tests := []struct { + name string + source *models.Column + target *models.Column + want int // number of changes + }{ + { + name: "identical columns", + source: &models.Column{Name: "id", Type: "integer"}, + target: &models.Column{Name: "id", Type: "integer"}, + want: 0, + }, + { + name: "type change", + source: &models.Column{Name: "id", Type: "integer"}, + target: &models.Column{Name: "id", Type: "bigint"}, + want: 1, + }, + { + name: "length change", + source: &models.Column{Name: "name", Type: "varchar", Length: 100}, + target: &models.Column{Name: "name", Type: "varchar", Length: 255}, + want: 1, + }, + { + name: "precision change", + source: &models.Column{Name: "price", Type: "numeric", Precision: 10}, + target: &models.Column{Name: "price", Type: "numeric", Precision: 12}, + want: 1, + }, + { + name: "scale change", + source: &models.Column{Name: "price", Type: "numeric", Scale: 2}, + target: &models.Column{Name: "price", Type: "numeric", Scale: 4}, + want: 1, + }, + { + name: "not null change", + source: &models.Column{Name: "name", Type: "text", NotNull: true}, + target: &models.Column{Name: "name", Type: "text", NotNull: false}, + want: 1, + }, + { + name: "auto increment change", + source: &models.Column{Name: "id", Type: "integer", AutoIncrement: true}, + target: &models.Column{Name: "id", Type: "integer", AutoIncrement: false}, + want: 1, + }, + { + name: "primary key change", + source: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: true}, + target: &models.Column{Name: "id", Type: "integer", IsPrimaryKey: false}, + want: 1, + }, + { + name: "multiple changes", + source: &models.Column{Name: "id", Type: "integer", NotNull: true, AutoIncrement: true}, + target: &models.Column{Name: "id", Type: "bigint", NotNull: false, AutoIncrement: false}, + want: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := compareColumnDetails(tt.source, tt.target) + if len(got) != tt.want { + t.Errorf("compareColumnDetails() = %d changes, want %d", len(got), tt.want) + } + }) + } +} + +func TestCompareIndexes(t *testing.T) { + tests := []struct { + name string + source map[string]*models.Index + target map[string]*models.Index + want func(*IndexDiff) bool + }{ + { + name: "identical indexes", + source: map[string]*models.Index{}, + target: map[string]*models.Index{}, + want: func(d *IndexDiff) bool { + return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0 + }, + }, + { + name: "missing index", + source: map[string]*models.Index{ + "idx_name": {Name: "idx_name", Columns: []string{"name"}}, + }, + target: map[string]*models.Index{}, + want: func(d *IndexDiff) bool { + return len(d.Missing) == 1 && d.Missing[0].Name == "idx_name" + }, + }, + { + name: "extra index", + source: map[string]*models.Index{}, + target: map[string]*models.Index{ + "idx_name": {Name: "idx_name", Columns: []string{"name"}}, + }, + want: func(d *IndexDiff) bool { + return len(d.Extra) == 1 && d.Extra[0].Name == "idx_name" + }, + }, + { + name: "modified index uniqueness", + source: map[string]*models.Index{ + "idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: false}, + }, + target: map[string]*models.Index{ + "idx_name": {Name: "idx_name", Columns: []string{"name"}, Unique: true}, + }, + want: func(d *IndexDiff) bool { + return len(d.Modified) == 1 && d.Modified[0].Name == "idx_name" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := compareIndexes(tt.source, tt.target) + if !tt.want(got) { + t.Errorf("compareIndexes() result doesn't match expectations") + } + }) + } +} + +func TestCompareConstraints(t *testing.T) { + tests := []struct { + name string + source map[string]*models.Constraint + target map[string]*models.Constraint + want func(*ConstraintDiff) bool + }{ + { + name: "identical constraints", + source: map[string]*models.Constraint{}, + target: map[string]*models.Constraint{}, + want: func(d *ConstraintDiff) bool { + return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0 + }, + }, + { + name: "missing constraint", + source: map[string]*models.Constraint{ + "pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}}, + }, + target: map[string]*models.Constraint{}, + want: func(d *ConstraintDiff) bool { + return len(d.Missing) == 1 && d.Missing[0].Name == "pk_id" + }, + }, + { + name: "extra constraint", + source: map[string]*models.Constraint{}, + target: map[string]*models.Constraint{ + "pk_id": {Name: "pk_id", Type: "PRIMARY KEY", Columns: []string{"id"}}, + }, + want: func(d *ConstraintDiff) bool { + return len(d.Extra) == 1 && d.Extra[0].Name == "pk_id" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := compareConstraints(tt.source, tt.target) + if !tt.want(got) { + t.Errorf("compareConstraints() result doesn't match expectations") + } + }) + } +} + +func TestCompareRelationships(t *testing.T) { + tests := []struct { + name string + source map[string]*models.Relationship + target map[string]*models.Relationship + want func(*RelationshipDiff) bool + }{ + { + name: "identical relationships", + source: map[string]*models.Relationship{}, + target: map[string]*models.Relationship{}, + want: func(d *RelationshipDiff) bool { + return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0 + }, + }, + { + name: "missing relationship", + source: map[string]*models.Relationship{ + "fk_user": {Name: "fk_user", Type: "FOREIGN KEY"}, + }, + target: map[string]*models.Relationship{}, + want: func(d *RelationshipDiff) bool { + return len(d.Missing) == 1 && d.Missing[0].Name == "fk_user" + }, + }, + { + name: "extra relationship", + source: map[string]*models.Relationship{}, + target: map[string]*models.Relationship{ + "fk_user": {Name: "fk_user", Type: "FOREIGN KEY"}, + }, + want: func(d *RelationshipDiff) bool { + return len(d.Extra) == 1 && d.Extra[0].Name == "fk_user" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := compareRelationships(tt.source, tt.target) + if !tt.want(got) { + t.Errorf("compareRelationships() result doesn't match expectations") + } + }) + } +} + +func TestCompareTables(t *testing.T) { + tests := []struct { + name string + source []*models.Table + target []*models.Table + want func(*TableDiff) bool + }{ + { + name: "identical tables", + source: []*models.Table{}, + target: []*models.Table{}, + want: func(d *TableDiff) bool { + return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0 + }, + }, + { + name: "missing table", + source: []*models.Table{ + {Name: "users", Schema: "public"}, + }, + target: []*models.Table{}, + want: func(d *TableDiff) bool { + return len(d.Missing) == 1 && d.Missing[0].Name == "users" + }, + }, + { + name: "extra table", + source: []*models.Table{}, + target: []*models.Table{ + {Name: "users", Schema: "public"}, + }, + want: func(d *TableDiff) bool { + return len(d.Extra) == 1 && d.Extra[0].Name == "users" + }, + }, + { + name: "modified table", + source: []*models.Table{ + { + Name: "users", + Schema: "public", + Columns: map[string]*models.Column{ + "id": {Name: "id", Type: "integer"}, + }, + }, + }, + target: []*models.Table{ + { + Name: "users", + Schema: "public", + Columns: map[string]*models.Column{ + "id": {Name: "id", Type: "bigint"}, + }, + }, + }, + want: func(d *TableDiff) bool { + return len(d.Modified) == 1 && d.Modified[0].Name == "users" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := compareTables(tt.source, tt.target) + if !tt.want(got) { + t.Errorf("compareTables() result doesn't match expectations") + } + }) + } +} + +func TestCompareSchemas(t *testing.T) { + tests := []struct { + name string + source []*models.Schema + target []*models.Schema + want func(*SchemaDiff) bool + }{ + { + name: "identical schemas", + source: []*models.Schema{}, + target: []*models.Schema{}, + want: func(d *SchemaDiff) bool { + return len(d.Missing) == 0 && len(d.Extra) == 0 && len(d.Modified) == 0 + }, + }, + { + name: "missing schema", + source: []*models.Schema{ + {Name: "public"}, + }, + target: []*models.Schema{}, + want: func(d *SchemaDiff) bool { + return len(d.Missing) == 1 && d.Missing[0].Name == "public" + }, + }, + { + name: "extra schema", + source: []*models.Schema{}, + target: []*models.Schema{ + {Name: "public"}, + }, + want: func(d *SchemaDiff) bool { + return len(d.Extra) == 1 && d.Extra[0].Name == "public" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := compareSchemas(tt.source, tt.target) + if !tt.want(got) { + t.Errorf("compareSchemas() result doesn't match expectations") + } + }) + } +} + +func TestIsEmpty(t *testing.T) { + tests := []struct { + name string + v interface{} + want bool + }{ + {"empty ColumnDiff", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, true}, + {"ColumnDiff with missing", &ColumnDiff{Missing: []*models.Column{{Name: "id"}}, Extra: []*models.Column{}, Modified: []*ColumnChange{}}, false}, + {"ColumnDiff with extra", &ColumnDiff{Missing: []*models.Column{}, Extra: []*models.Column{{Name: "id"}}, Modified: []*ColumnChange{}}, false}, + {"empty IndexDiff", &IndexDiff{Missing: []*models.Index{}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, true}, + {"IndexDiff with missing", &IndexDiff{Missing: []*models.Index{{Name: "idx"}}, Extra: []*models.Index{}, Modified: []*IndexChange{}}, false}, + {"empty TableDiff", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{}, Modified: []*TableChange{}}, true}, + {"TableDiff with extra", &TableDiff{Missing: []*models.Table{}, Extra: []*models.Table{{Name: "users"}}, Modified: []*TableChange{}}, false}, + {"empty ConstraintDiff", &ConstraintDiff{Missing: []*models.Constraint{}, Extra: []*models.Constraint{}, Modified: []*ConstraintChange{}}, true}, + {"empty RelationshipDiff", &RelationshipDiff{Missing: []*models.Relationship{}, Extra: []*models.Relationship{}, Modified: []*RelationshipChange{}}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isEmpty(tt.v) + if got != tt.want { + t.Errorf("isEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestComputeSummary(t *testing.T) { + tests := []struct { + name string + result *DiffResult + want func(*Summary) bool + }{ + { + name: "empty diff", + result: &DiffResult{ + Schemas: &SchemaDiff{ + Missing: []*models.Schema{}, + Extra: []*models.Schema{}, + Modified: []*SchemaChange{}, + }, + }, + want: func(s *Summary) bool { + return s.Schemas.Missing == 0 && s.Schemas.Extra == 0 && s.Schemas.Modified == 0 + }, + }, + { + name: "schemas with differences", + result: &DiffResult{ + Schemas: &SchemaDiff{ + Missing: []*models.Schema{{Name: "schema1"}}, + Extra: []*models.Schema{{Name: "schema2"}, {Name: "schema3"}}, + Modified: []*SchemaChange{ + {Name: "public"}, + }, + }, + }, + want: func(s *Summary) bool { + return s.Schemas.Missing == 1 && s.Schemas.Extra == 2 && s.Schemas.Modified == 1 + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ComputeSummary(tt.result) + if !tt.want(got) { + t.Errorf("ComputeSummary() result doesn't match expectations") + } + }) + } +} diff --git a/pkg/diff/formatters_test.go b/pkg/diff/formatters_test.go new file mode 100644 index 0000000..bde749b --- /dev/null +++ b/pkg/diff/formatters_test.go @@ -0,0 +1,440 @@ +package diff + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" +) + +func TestFormatDiff(t *testing.T) { + result := &DiffResult{ + Source: "source_db", + Target: "target_db", + Schemas: &SchemaDiff{ + Missing: []*models.Schema{}, + Extra: []*models.Schema{}, + Modified: []*SchemaChange{}, + }, + } + + tests := []struct { + name string + format OutputFormat + wantErr bool + }{ + {"summary format", FormatSummary, false}, + {"json format", FormatJSON, false}, + {"html format", FormatHTML, false}, + {"invalid format", OutputFormat("invalid"), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + err := FormatDiff(result, tt.format, &buf) + + if (err != nil) != tt.wantErr { + t.Errorf("FormatDiff() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && buf.Len() == 0 { + t.Error("FormatDiff() produced empty output") + } + }) + } +} + +func TestFormatSummary(t *testing.T) { + tests := []struct { + name string + result *DiffResult + wantStr []string // strings that should appear in output + }{ + { + name: "no differences", + result: &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Missing: []*models.Schema{}, + Extra: []*models.Schema{}, + Modified: []*SchemaChange{}, + }, + }, + wantStr: []string{"source", "target", "No differences found"}, + }, + { + name: "with schema differences", + result: &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Missing: []*models.Schema{{Name: "schema1"}}, + Extra: []*models.Schema{{Name: "schema2"}}, + Modified: []*SchemaChange{ + {Name: "public"}, + }, + }, + }, + wantStr: []string{"Schemas:", "Missing: 1", "Extra: 1", "Modified: 1"}, + }, + { + name: "with table differences", + result: &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Modified: []*SchemaChange{ + { + Name: "public", + Tables: &TableDiff{ + Missing: []*models.Table{{Name: "users"}}, + Extra: []*models.Table{{Name: "posts"}}, + Modified: []*TableChange{ + {Name: "comments", Schema: "public"}, + }, + }, + }, + }, + }, + }, + wantStr: []string{"Tables:", "Missing: 1", "Extra: 1", "Modified: 1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + err := formatSummary(tt.result, &buf) + + if err != nil { + t.Errorf("formatSummary() error = %v", err) + return + } + + output := buf.String() + for _, want := range tt.wantStr { + if !strings.Contains(output, want) { + t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output) + } + } + }) + } +} + +func TestFormatJSON(t *testing.T) { + result := &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Missing: []*models.Schema{{Name: "schema1"}}, + Extra: []*models.Schema{}, + Modified: []*SchemaChange{}, + }, + } + + var buf bytes.Buffer + err := formatJSON(result, &buf) + + if err != nil { + t.Errorf("formatJSON() error = %v", err) + return + } + + // Check if output is valid JSON + var decoded DiffResult + if err := json.Unmarshal(buf.Bytes(), &decoded); err != nil { + t.Errorf("formatJSON() produced invalid JSON: %v", err) + } + + // Check basic structure + if decoded.Source != "source" { + t.Errorf("formatJSON() source = %v, want %v", decoded.Source, "source") + } + if decoded.Target != "target" { + t.Errorf("formatJSON() target = %v, want %v", decoded.Target, "target") + } + if len(decoded.Schemas.Missing) != 1 { + t.Errorf("formatJSON() missing schemas = %v, want 1", len(decoded.Schemas.Missing)) + } +} + +func TestFormatHTML(t *testing.T) { + tests := []struct { + name string + result *DiffResult + wantStr []string // HTML elements/content that should appear + }{ + { + name: "basic HTML structure", + result: &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Missing: []*models.Schema{}, + Extra: []*models.Schema{}, + Modified: []*SchemaChange{}, + }, + }, + wantStr: []string{ + "", + "Database Diff Report", + "source", + "target", + }, + }, + { + name: "with schema differences", + result: &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Missing: []*models.Schema{{Name: "missing_schema"}}, + Extra: []*models.Schema{{Name: "extra_schema"}}, + Modified: []*SchemaChange{}, + }, + }, + wantStr: []string{ + "", + "missing_schema", + "extra_schema", + "MISSING", + "EXTRA", + }, + }, + { + name: "with table modifications", + result: &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Modified: []*SchemaChange{ + { + Name: "public", + Tables: &TableDiff{ + Modified: []*TableChange{ + { + Name: "users", + Schema: "public", + Columns: &ColumnDiff{ + Missing: []*models.Column{{Name: "email", Type: "text"}}, + }, + }, + }, + }, + }, + }, + }, + }, + wantStr: []string{ + "public", + "users", + "email", + "text", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + err := formatHTML(tt.result, &buf) + + if err != nil { + t.Errorf("formatHTML() error = %v", err) + return + } + + output := buf.String() + for _, want := range tt.wantStr { + if !strings.Contains(output, want) { + t.Errorf("formatHTML() output doesn't contain %q", want) + } + } + }) + } +} + +func TestFormatSummaryWithColumns(t *testing.T) { + result := &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Modified: []*SchemaChange{ + { + Name: "public", + Tables: &TableDiff{ + Modified: []*TableChange{ + { + Name: "users", + Schema: "public", + Columns: &ColumnDiff{ + Missing: []*models.Column{{Name: "email"}}, + Extra: []*models.Column{{Name: "phone"}, {Name: "address"}}, + Modified: []*ColumnChange{ + {Name: "name"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + + var buf bytes.Buffer + err := formatSummary(result, &buf) + + if err != nil { + t.Errorf("formatSummary() error = %v", err) + return + } + + output := buf.String() + wantStrings := []string{ + "Columns:", + "Missing: 1", + "Extra: 2", + "Modified: 1", + } + + for _, want := range wantStrings { + if !strings.Contains(output, want) { + t.Errorf("formatSummary() output doesn't contain %q\nGot: %s", want, output) + } + } +} + +func TestFormatSummaryWithIndexes(t *testing.T) { + result := &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Modified: []*SchemaChange{ + { + Name: "public", + Tables: &TableDiff{ + Modified: []*TableChange{ + { + Name: "users", + Schema: "public", + Indexes: &IndexDiff{ + Missing: []*models.Index{{Name: "idx_email"}}, + Extra: []*models.Index{{Name: "idx_phone"}}, + Modified: []*IndexChange{{Name: "idx_name"}}, + }, + }, + }, + }, + }, + }, + }, + } + + var buf bytes.Buffer + err := formatSummary(result, &buf) + + if err != nil { + t.Errorf("formatSummary() error = %v", err) + return + } + + output := buf.String() + if !strings.Contains(output, "Indexes:") { + t.Error("formatSummary() output doesn't contain Indexes section") + } + if !strings.Contains(output, "Missing: 1") { + t.Error("formatSummary() output doesn't contain correct missing count") + } +} + +func TestFormatSummaryWithConstraints(t *testing.T) { + result := &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Modified: []*SchemaChange{ + { + Name: "public", + Tables: &TableDiff{ + Modified: []*TableChange{ + { + Name: "users", + Schema: "public", + Constraints: &ConstraintDiff{ + Missing: []*models.Constraint{{Name: "pk_users", Type: "PRIMARY KEY"}}, + Extra: []*models.Constraint{{Name: "fk_users_roles", Type: "FOREIGN KEY"}}, + }, + }, + }, + }, + }, + }, + }, + } + + var buf bytes.Buffer + err := formatSummary(result, &buf) + + if err != nil { + t.Errorf("formatSummary() error = %v", err) + return + } + + output := buf.String() + if !strings.Contains(output, "Constraints:") { + t.Error("formatSummary() output doesn't contain Constraints section") + } +} + +func TestFormatJSONIndentation(t *testing.T) { + result := &DiffResult{ + Source: "source", + Target: "target", + Schemas: &SchemaDiff{ + Missing: []*models.Schema{{Name: "test"}}, + }, + } + + var buf bytes.Buffer + err := formatJSON(result, &buf) + + if err != nil { + t.Errorf("formatJSON() error = %v", err) + return + } + + // Check that JSON is indented (has newlines and spaces) + output := buf.String() + if !strings.Contains(output, "\n") { + t.Error("formatJSON() should produce indented JSON with newlines") + } + if !strings.Contains(output, " ") { + t.Error("formatJSON() should produce indented JSON with spaces") + } +} + +func TestOutputFormatConstants(t *testing.T) { + tests := []struct { + name string + format OutputFormat + want string + }{ + {"summary constant", FormatSummary, "summary"}, + {"json constant", FormatJSON, "json"}, + {"html constant", FormatHTML, "html"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if string(tt.format) != tt.want { + t.Errorf("OutputFormat %v = %v, want %v", tt.name, tt.format, tt.want) + } + }) + } +} diff --git a/pkg/inspector/inspector_test.go b/pkg/inspector/inspector_test.go new file mode 100644 index 0000000..016f97c --- /dev/null +++ b/pkg/inspector/inspector_test.go @@ -0,0 +1,238 @@ +package inspector + +import ( + "testing" +) + +func TestNewInspector(t *testing.T) { + db := createTestDatabase() + config := GetDefaultConfig() + + inspector := NewInspector(db, config) + + if inspector == nil { + t.Fatal("NewInspector() returned nil") + } + + if inspector.db != db { + t.Error("NewInspector() database not set correctly") + } + + if inspector.config != config { + t.Error("NewInspector() config not set correctly") + } +} + +func TestInspect(t *testing.T) { + db := createTestDatabase() + config := GetDefaultConfig() + + inspector := NewInspector(db, config) + report, err := inspector.Inspect() + + if err != nil { + t.Fatalf("Inspect() returned error: %v", err) + } + + if report == nil { + t.Fatal("Inspect() returned nil report") + } + + if report.Database != db.Name { + t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name) + } + + if report.Summary.TotalRules != len(config.Rules) { + t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules)) + } + + if len(report.Violations) == 0 { + t.Error("Inspect() returned no violations, expected some results") + } +} + +func TestInspectWithDisabledRules(t *testing.T) { + db := createTestDatabase() + config := GetDefaultConfig() + + // Disable all rules + for name := range config.Rules { + rule := config.Rules[name] + rule.Enabled = "off" + config.Rules[name] = rule + } + + inspector := NewInspector(db, config) + report, err := inspector.Inspect() + + if err != nil { + t.Fatalf("Inspect() with disabled rules returned error: %v", err) + } + + if report.Summary.RulesChecked != 0 { + t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked) + } + + if report.Summary.RulesSkipped != len(config.Rules) { + t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules)) + } +} + +func TestInspectWithEnforcedRules(t *testing.T) { + db := createTestDatabase() + config := GetDefaultConfig() + + // Enable only one rule and enforce it + for name := range config.Rules { + rule := config.Rules[name] + rule.Enabled = "off" + config.Rules[name] = rule + } + + primaryKeyRule := config.Rules["primary_key_naming"] + primaryKeyRule.Enabled = "enforce" + primaryKeyRule.Pattern = "^id$" + config.Rules["primary_key_naming"] = primaryKeyRule + + inspector := NewInspector(db, config) + report, err := inspector.Inspect() + + if err != nil { + t.Fatalf("Inspect() returned error: %v", err) + } + + if report.Summary.RulesChecked != 1 { + t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked) + } + + // All results should be at error level for enforced rules + for _, violation := range report.Violations { + if violation.Level != "error" { + t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level) + } + } +} + +func TestGenerateSummary(t *testing.T) { + db := createTestDatabase() + config := GetDefaultConfig() + inspector := NewInspector(db, config) + + results := []ValidationResult{ + {RuleName: "rule1", Passed: true, Level: "error"}, + {RuleName: "rule2", Passed: false, Level: "error"}, + {RuleName: "rule3", Passed: false, Level: "warning"}, + {RuleName: "rule4", Passed: true, Level: "warning"}, + } + + summary := inspector.generateSummary(results) + + if summary.PassedCount != 2 { + t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount) + } + + if summary.ErrorCount != 1 { + t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount) + } + + if summary.WarningCount != 1 { + t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount) + } +} + +func TestHasErrors(t *testing.T) { + tests := []struct { + name string + report *InspectorReport + want bool + }{ + { + name: "with errors", + report: &InspectorReport{ + Summary: ReportSummary{ + ErrorCount: 5, + }, + }, + want: true, + }, + { + name: "without errors", + report: &InspectorReport{ + Summary: ReportSummary{ + ErrorCount: 0, + WarningCount: 3, + }, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.report.HasErrors(); got != tt.want { + t.Errorf("HasErrors() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetValidator(t *testing.T) { + tests := []struct { + name string + functionName string + wantExists bool + }{ + {"primary_key_naming", "primary_key_naming", true}, + {"primary_key_datatype", "primary_key_datatype", true}, + {"foreign_key_column_naming", "foreign_key_column_naming", true}, + {"table_regexpr", "table_regexpr", true}, + {"column_regexpr", "column_regexpr", true}, + {"reserved_words", "reserved_words", true}, + {"have_primary_key", "have_primary_key", true}, + {"orphaned_foreign_key", "orphaned_foreign_key", true}, + {"circular_dependency", "circular_dependency", true}, + {"unknown_function", "unknown_function", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, exists := getValidator(tt.functionName) + if exists != tt.wantExists { + t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists) + } + }) + } +} + +func TestCreateResult(t *testing.T) { + result := createResult( + "test_rule", + true, + "Test message", + "schema.table.column", + map[string]interface{}{ + "key1": "value1", + "key2": 42, + }, + ) + + if result.RuleName != "test_rule" { + t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName) + } + + if !result.Passed { + t.Error("createResult() Passed = false, want true") + } + + if result.Message != "Test message" { + t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message) + } + + if result.Location != "schema.table.column" { + t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location) + } + + if len(result.Context) != 2 { + t.Errorf("createResult() Context length = %d, want 2", len(result.Context)) + } +} diff --git a/pkg/inspector/report_test.go b/pkg/inspector/report_test.go new file mode 100644 index 0000000..e4ef7cb --- /dev/null +++ b/pkg/inspector/report_test.go @@ -0,0 +1,366 @@ +package inspector + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + "time" +) + +func createTestReport() *InspectorReport { + return &InspectorReport{ + Summary: ReportSummary{ + TotalRules: 10, + RulesChecked: 8, + RulesSkipped: 2, + ErrorCount: 3, + WarningCount: 5, + PassedCount: 12, + }, + Violations: []ValidationResult{ + { + RuleName: "primary_key_naming", + Level: "error", + Message: "Primary key should start with 'id_'", + Location: "public.users.user_id", + Passed: false, + Context: map[string]interface{}{ + "schema": "public", + "table": "users", + "column": "user_id", + "pattern": "^id_", + }, + }, + { + RuleName: "table_name_length", + Level: "warning", + Message: "Table name too long", + Location: "public.very_long_table_name_that_exceeds_limits", + Passed: false, + Context: map[string]interface{}{ + "schema": "public", + "table": "very_long_table_name_that_exceeds_limits", + "length": 44, + "max_length": 32, + }, + }, + }, + GeneratedAt: time.Now(), + Database: "testdb", + SourceFormat: "postgresql", + } +} + +func TestNewMarkdownFormatter(t *testing.T) { + var buf bytes.Buffer + formatter := NewMarkdownFormatter(&buf) + + if formatter == nil { + t.Fatal("NewMarkdownFormatter() returned nil") + } + + // Buffer is not a terminal, so colors should be disabled + if formatter.UseColors { + t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal") + } +} + +func TestNewJSONFormatter(t *testing.T) { + formatter := NewJSONFormatter() + + if formatter == nil { + t.Fatal("NewJSONFormatter() returned nil") + } +} + +func TestMarkdownFormatter_Format(t *testing.T) { + report := createTestReport() + var buf bytes.Buffer + formatter := NewMarkdownFormatter(&buf) + + output, err := formatter.Format(report) + if err != nil { + t.Fatalf("MarkdownFormatter.Format() returned error: %v", err) + } + + // Check that output contains expected sections + if !strings.Contains(output, "# RelSpec Inspector Report") { + t.Error("Markdown output missing header") + } + + if !strings.Contains(output, "Database:") { + t.Error("Markdown output missing database field") + } + + if !strings.Contains(output, "testdb") { + t.Error("Markdown output missing database name") + } + + if !strings.Contains(output, "Summary") { + t.Error("Markdown output missing summary section") + } + + if !strings.Contains(output, "Rules Checked: 8") { + t.Error("Markdown output missing rules checked count") + } + + if !strings.Contains(output, "Errors: 3") { + t.Error("Markdown output missing error count") + } + + if !strings.Contains(output, "Warnings: 5") { + t.Error("Markdown output missing warning count") + } + + if !strings.Contains(output, "Violations") { + t.Error("Markdown output missing violations section") + } + + if !strings.Contains(output, "primary_key_naming") { + t.Error("Markdown output missing rule name") + } + + if !strings.Contains(output, "public.users.user_id") { + t.Error("Markdown output missing location") + } +} + +func TestMarkdownFormatter_FormatNoViolations(t *testing.T) { + report := &InspectorReport{ + Summary: ReportSummary{ + TotalRules: 10, + RulesChecked: 10, + RulesSkipped: 0, + ErrorCount: 0, + WarningCount: 0, + PassedCount: 50, + }, + Violations: []ValidationResult{}, + GeneratedAt: time.Now(), + Database: "testdb", + SourceFormat: "postgresql", + } + + var buf bytes.Buffer + formatter := NewMarkdownFormatter(&buf) + + output, err := formatter.Format(report) + if err != nil { + t.Fatalf("MarkdownFormatter.Format() returned error: %v", err) + } + + if !strings.Contains(output, "No violations found") { + t.Error("Markdown output should indicate no violations") + } +} + +func TestJSONFormatter_Format(t *testing.T) { + report := createTestReport() + formatter := NewJSONFormatter() + + output, err := formatter.Format(report) + if err != nil { + t.Fatalf("JSONFormatter.Format() returned error: %v", err) + } + + // Verify it's valid JSON + var decoded InspectorReport + if err := json.Unmarshal([]byte(output), &decoded); err != nil { + t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err) + } + + // Check key fields + if decoded.Database != "testdb" { + t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database) + } + + if decoded.Summary.ErrorCount != 3 { + t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount) + } + + if len(decoded.Violations) != 2 { + t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations)) + } +} + +func TestMarkdownFormatter_FormatHeader(t *testing.T) { + var buf bytes.Buffer + formatter := NewMarkdownFormatter(&buf) + + header := formatter.formatHeader("Test Header") + + if !strings.Contains(header, "# Test Header") { + t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header) + } +} + +func TestMarkdownFormatter_FormatBold(t *testing.T) { + tests := []struct { + name string + useColors bool + text string + wantContains string + }{ + { + name: "without colors", + useColors: false, + text: "Bold Text", + wantContains: "**Bold Text**", + }, + { + name: "with colors", + useColors: true, + text: "Bold Text", + wantContains: "Bold Text", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + formatter := &MarkdownFormatter{UseColors: tt.useColors} + result := formatter.formatBold(tt.text) + + if !strings.Contains(result, tt.wantContains) { + t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains) + } + }) + } +} + +func TestMarkdownFormatter_Colorize(t *testing.T) { + tests := []struct { + name string + useColors bool + text string + color string + wantColor bool + }{ + { + name: "without colors", + useColors: false, + text: "Test", + color: colorRed, + wantColor: false, + }, + { + name: "with colors", + useColors: true, + text: "Test", + color: colorRed, + wantColor: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + formatter := &MarkdownFormatter{UseColors: tt.useColors} + result := formatter.colorize(tt.text, tt.color) + + hasColor := strings.Contains(result, tt.color) + if hasColor != tt.wantColor { + t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor) + } + + if !strings.Contains(result, tt.text) { + t.Errorf("colorize() doesn't contain original text %q", tt.text) + } + }) + } +} + +func TestMarkdownFormatter_FormatContext(t *testing.T) { + formatter := &MarkdownFormatter{UseColors: false} + + context := map[string]interface{}{ + "schema": "public", + "table": "users", + "column": "id", + "pattern": "^id_", + "max_length": 64, + } + + result := formatter.formatContext(context) + + // Should not include schema, table, column (they're in location) + if strings.Contains(result, "schema") { + t.Error("formatContext() should skip schema field") + } + + if strings.Contains(result, "table=") { + t.Error("formatContext() should skip table field") + } + + if strings.Contains(result, "column=") { + t.Error("formatContext() should skip column field") + } + + // Should include other fields + if !strings.Contains(result, "pattern") { + t.Error("formatContext() should include pattern field") + } + + if !strings.Contains(result, "max_length") { + t.Error("formatContext() should include max_length field") + } +} + +func TestMarkdownFormatter_FormatViolation(t *testing.T) { + formatter := &MarkdownFormatter{UseColors: false} + + violation := ValidationResult{ + RuleName: "test_rule", + Level: "error", + Message: "Test violation message", + Location: "public.users.id", + Passed: false, + Context: map[string]interface{}{ + "pattern": "^id_", + }, + } + + result := formatter.formatViolation(violation, colorRed) + + if !strings.Contains(result, "test_rule") { + t.Error("formatViolation() should include rule name") + } + + if !strings.Contains(result, "Test violation message") { + t.Error("formatViolation() should include message") + } + + if !strings.Contains(result, "public.users.id") { + t.Error("formatViolation() should include location") + } + + if !strings.Contains(result, "Location:") { + t.Error("formatViolation() should include Location label") + } + + if !strings.Contains(result, "Message:") { + t.Error("formatViolation() should include Message label") + } +} + +func TestReportFormatConstants(t *testing.T) { + // Test that color constants are defined + if colorReset == "" { + t.Error("colorReset is not defined") + } + + if colorRed == "" { + t.Error("colorRed is not defined") + } + + if colorYellow == "" { + t.Error("colorYellow is not defined") + } + + if colorGreen == "" { + t.Error("colorGreen is not defined") + } + + if colorBold == "" { + t.Error("colorBold is not defined") + } +} diff --git a/pkg/inspector/rules_test.go b/pkg/inspector/rules_test.go new file mode 100644 index 0000000..efdc66a --- /dev/null +++ b/pkg/inspector/rules_test.go @@ -0,0 +1,249 @@ +package inspector + +import ( + "os" + "path/filepath" + "testing" +) + +func TestGetDefaultConfig(t *testing.T) { + config := GetDefaultConfig() + + if config == nil { + t.Fatal("GetDefaultConfig() returned nil") + } + + if config.Version != "1.0" { + t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version) + } + + if len(config.Rules) == 0 { + t.Error("GetDefaultConfig() returned no rules") + } + + // Check that all expected rules are present + expectedRules := []string{ + "primary_key_naming", + "primary_key_datatype", + "primary_key_auto_increment", + "foreign_key_column_naming", + "foreign_key_constraint_naming", + "foreign_key_index", + "table_naming_case", + "column_naming_case", + "table_name_length", + "column_name_length", + "reserved_keywords", + "missing_primary_key", + "orphaned_foreign_key", + "circular_dependency", + } + + for _, ruleName := range expectedRules { + if _, exists := config.Rules[ruleName]; !exists { + t.Errorf("GetDefaultConfig() missing rule: %q", ruleName) + } + } +} + +func TestLoadConfig_NonExistentFile(t *testing.T) { + // Try to load a non-existent file + config, err := LoadConfig("/path/to/nonexistent/file.yaml") + + if err != nil { + t.Fatalf("LoadConfig() with non-existent file returned error: %v", err) + } + + // Should return default config + if config == nil { + t.Fatal("LoadConfig() returned nil config for non-existent file") + } + + if len(config.Rules) == 0 { + t.Error("LoadConfig() returned config with no rules") + } +} + +func TestLoadConfig_ValidFile(t *testing.T) { + // Create a temporary config file + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "test-config.yaml") + + configContent := `version: "1.0" +rules: + primary_key_naming: + enabled: "enforce" + function: "primary_key_naming" + pattern: "^pk_" + message: "Primary keys must start with pk_" + table_name_length: + enabled: "warn" + function: "table_name_length" + max_length: 50 + message: "Table name too long" +` + + err := os.WriteFile(configPath, []byte(configContent), 0644) + if err != nil { + t.Fatalf("Failed to create test config file: %v", err) + } + + config, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() returned error: %v", err) + } + + if config.Version != "1.0" { + t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version) + } + + if len(config.Rules) != 2 { + t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules)) + } + + // Check primary_key_naming rule + pkRule, exists := config.Rules["primary_key_naming"] + if !exists { + t.Fatal("LoadConfig() missing primary_key_naming rule") + } + + if pkRule.Enabled != "enforce" { + t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled) + } + + if pkRule.Pattern != "^pk_" { + t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern) + } + + // Check table_name_length rule + lengthRule, exists := config.Rules["table_name_length"] + if !exists { + t.Fatal("LoadConfig() missing table_name_length rule") + } + + if lengthRule.MaxLength != 50 { + t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength) + } +} + +func TestLoadConfig_InvalidYAML(t *testing.T) { + // Create a temporary invalid config file + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "invalid-config.yaml") + + invalidContent := `invalid: yaml: content: {[}]` + + err := os.WriteFile(configPath, []byte(invalidContent), 0644) + if err != nil { + t.Fatalf("Failed to create test config file: %v", err) + } + + _, err = LoadConfig(configPath) + if err == nil { + t.Error("LoadConfig() with invalid YAML did not return error") + } +} + +func TestRuleIsEnabled(t *testing.T) { + tests := []struct { + name string + rule Rule + want bool + }{ + { + name: "enforce is enabled", + rule: Rule{Enabled: "enforce"}, + want: true, + }, + { + name: "warn is enabled", + rule: Rule{Enabled: "warn"}, + want: true, + }, + { + name: "off is not enabled", + rule: Rule{Enabled: "off"}, + want: false, + }, + { + name: "empty is not enabled", + rule: Rule{Enabled: ""}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.rule.IsEnabled(); got != tt.want { + t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRuleIsEnforced(t *testing.T) { + tests := []struct { + name string + rule Rule + want bool + }{ + { + name: "enforce is enforced", + rule: Rule{Enabled: "enforce"}, + want: true, + }, + { + name: "warn is not enforced", + rule: Rule{Enabled: "warn"}, + want: false, + }, + { + name: "off is not enforced", + rule: Rule{Enabled: "off"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.rule.IsEnforced(); got != tt.want { + t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDefaultConfigRuleSettings(t *testing.T) { + config := GetDefaultConfig() + + // Test specific rule settings + pkNamingRule := config.Rules["primary_key_naming"] + if pkNamingRule.Function != "primary_key_naming" { + t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function) + } + + if pkNamingRule.Pattern != "^id_" { + t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern) + } + + // Test datatype rule + pkDatatypeRule := config.Rules["primary_key_datatype"] + if len(pkDatatypeRule.AllowedTypes) == 0 { + t.Error("primary_key_datatype has no allowed types") + } + + // Test length rule + tableLengthRule := config.Rules["table_name_length"] + if tableLengthRule.MaxLength != 64 { + t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength) + } + + // Test reserved keywords rule + reservedRule := config.Rules["reserved_keywords"] + if !reservedRule.CheckTables { + t.Error("reserved_keywords.CheckTables should be true") + } + if !reservedRule.CheckColumns { + t.Error("reserved_keywords.CheckColumns should be true") + } +} diff --git a/pkg/inspector/validators_test.go b/pkg/inspector/validators_test.go new file mode 100644 index 0000000..a80737d --- /dev/null +++ b/pkg/inspector/validators_test.go @@ -0,0 +1,837 @@ +package inspector + +import ( + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" +) + +// Helper function to create test database +func createTestDatabase() *models.Database { + return &models.Database{ + Name: "testdb", + Schemas: []*models.Schema{ + { + Name: "public", + Tables: []*models.Table{ + { + Name: "users", + Columns: map[string]*models.Column{ + "id": { + Name: "id", + Type: "bigserial", + IsPrimaryKey: true, + AutoIncrement: true, + }, + "username": { + Name: "username", + Type: "varchar(50)", + NotNull: true, + IsPrimaryKey: false, + }, + "rid_organization": { + Name: "rid_organization", + Type: "bigint", + NotNull: true, + IsPrimaryKey: false, + }, + }, + Constraints: map[string]*models.Constraint{ + "fk_users_organization": { + Name: "fk_users_organization", + Type: models.ForeignKeyConstraint, + Columns: []string{"rid_organization"}, + ReferencedTable: "organizations", + ReferencedSchema: "public", + ReferencedColumns: []string{"id"}, + }, + }, + Indexes: map[string]*models.Index{ + "idx_rid_organization": { + Name: "idx_rid_organization", + Columns: []string{"rid_organization"}, + }, + }, + }, + { + Name: "organizations", + Columns: map[string]*models.Column{ + "id": { + Name: "id", + Type: "bigserial", + IsPrimaryKey: true, + AutoIncrement: true, + }, + "name": { + Name: "name", + Type: "varchar(100)", + NotNull: true, + IsPrimaryKey: false, + }, + }, + }, + }, + }, + }, + } +} + +func TestValidatePrimaryKeyNaming(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "matching pattern id", + rule: Rule{ + Pattern: "^id$", + Message: "Primary key should be 'id'", + }, + wantLen: 2, + wantPass: true, + }, + { + name: "non-matching pattern id_", + rule: Rule{ + Pattern: "^id_", + Message: "Primary key should start with 'id_'", + }, + wantLen: 2, + wantPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validatePrimaryKeyNaming(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen) + } + if len(results) > 0 && results[0].Passed != tt.wantPass { + t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass) + } + }) + } +} + +func TestValidatePrimaryKeyDatatype(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "allowed type bigserial", + rule: Rule{ + AllowedTypes: []string{"bigserial", "bigint", "int"}, + Message: "Primary key should use integer types", + }, + wantLen: 2, + wantPass: true, + }, + { + name: "disallowed type", + rule: Rule{ + AllowedTypes: []string{"uuid"}, + Message: "Primary key should use UUID", + }, + wantLen: 2, + wantPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen) + } + if len(results) > 0 && results[0].Passed != tt.wantPass { + t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass) + } + }) + } +} + +func TestValidatePrimaryKeyAutoIncrement(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + }{ + { + name: "require auto increment", + rule: Rule{ + RequireAutoIncrement: true, + Message: "Primary key should have auto-increment", + }, + wantLen: 0, // No violations - all PKs have auto-increment + }, + { + name: "disallow auto increment", + rule: Rule{ + RequireAutoIncrement: false, + Message: "Primary key should not have auto-increment", + }, + wantLen: 2, // 2 violations - both PKs have auto-increment + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen) + } + }) + } +} + +func TestValidateForeignKeyColumnNaming(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "matching pattern rid_", + rule: Rule{ + Pattern: "^rid_", + Message: "Foreign key columns should start with 'rid_'", + }, + wantLen: 1, + wantPass: true, + }, + { + name: "non-matching pattern fk_", + rule: Rule{ + Pattern: "^fk_", + Message: "Foreign key columns should start with 'fk_'", + }, + wantLen: 1, + wantPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen) + } + if len(results) > 0 && results[0].Passed != tt.wantPass { + t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass) + } + }) + } +} + +func TestValidateForeignKeyConstraintNaming(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "matching pattern fk_", + rule: Rule{ + Pattern: "^fk_", + Message: "Foreign key constraints should start with 'fk_'", + }, + wantLen: 1, + wantPass: true, + }, + { + name: "non-matching pattern FK_", + rule: Rule{ + Pattern: "^FK_", + Message: "Foreign key constraints should start with 'FK_'", + }, + wantLen: 1, + wantPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen) + } + if len(results) > 0 && results[0].Passed != tt.wantPass { + t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass) + } + }) + } +} + +func TestValidateForeignKeyIndex(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "require index with index present", + rule: Rule{ + RequireIndex: true, + Message: "Foreign key columns should have indexes", + }, + wantLen: 1, + wantPass: true, + }, + { + name: "no requirement", + rule: Rule{ + RequireIndex: false, + Message: "Foreign key index check disabled", + }, + wantLen: 0, + wantPass: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validateForeignKeyIndex(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen) + } + if len(results) > 0 && results[0].Passed != tt.wantPass { + t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass) + } + }) + } +} + +func TestValidateTableNamingCase(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "lowercase snake_case pattern", + rule: Rule{ + Pattern: "^[a-z][a-z0-9_]*$", + Case: "lowercase", + Message: "Table names should be lowercase snake_case", + }, + wantLen: 2, + wantPass: true, + }, + { + name: "uppercase pattern", + rule: Rule{ + Pattern: "^[A-Z][A-Z0-9_]*$", + Case: "uppercase", + Message: "Table names should be uppercase", + }, + wantLen: 2, + wantPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validateTableNamingCase(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen) + } + if len(results) > 0 && results[0].Passed != tt.wantPass { + t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass) + } + }) + } +} + +func TestValidateColumnNamingCase(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "lowercase snake_case pattern", + rule: Rule{ + Pattern: "^[a-z][a-z0-9_]*$", + Case: "lowercase", + Message: "Column names should be lowercase snake_case", + }, + wantLen: 5, // 5 total columns across both tables + wantPass: true, + }, + { + name: "camelCase pattern", + rule: Rule{ + Pattern: "^[a-z][a-zA-Z0-9]*$", + Case: "camelCase", + Message: "Column names should be camelCase", + }, + wantLen: 5, + wantPass: false, // rid_organization has underscore + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validateColumnNamingCase(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen) + } + }) + } +} + +func TestValidateTableNameLength(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "max length 64", + rule: Rule{ + MaxLength: 64, + Message: "Table name too long", + }, + wantLen: 2, + wantPass: true, + }, + { + name: "max length 5", + rule: Rule{ + MaxLength: 5, + Message: "Table name too long", + }, + wantLen: 2, + wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validateTableNameLength(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen) + } + }) + } +} + +func TestValidateColumnNameLength(t *testing.T) { + db := createTestDatabase() + + tests := []struct { + name string + rule Rule + wantLen int + wantPass bool + }{ + { + name: "max length 64", + rule: Rule{ + MaxLength: 64, + Message: "Column name too long", + }, + wantLen: 5, + wantPass: true, + }, + { + name: "max length 5", + rule: Rule{ + MaxLength: 5, + Message: "Column name too long", + }, + wantLen: 5, + wantPass: false, // Some columns exceed 5 chars + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validateColumnNameLength(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen) + } + }) + } +} + +func TestValidateReservedKeywords(t *testing.T) { + // Create a database with reserved keywords + db := &models.Database{ + Name: "testdb", + Schemas: []*models.Schema{ + { + Name: "public", + Tables: []*models.Table{ + { + Name: "user", // "user" is a reserved keyword + Columns: map[string]*models.Column{ + "id": { + Name: "id", + Type: "bigint", + IsPrimaryKey: true, + }, + "select": { // "select" is a reserved keyword + Name: "select", + Type: "varchar(50)", + }, + }, + }, + }, + }, + }, + } + + tests := []struct { + name string + rule Rule + wantLen int + checkPasses bool + }{ + { + name: "check tables only", + rule: Rule{ + CheckTables: true, + CheckColumns: false, + Message: "Reserved keyword used", + }, + wantLen: 1, // "user" table + checkPasses: false, + }, + { + name: "check columns only", + rule: Rule{ + CheckTables: false, + CheckColumns: true, + Message: "Reserved keyword used", + }, + wantLen: 2, // "id", "select" columns (id passes, select fails) + checkPasses: false, + }, + { + name: "check both", + rule: Rule{ + CheckTables: true, + CheckColumns: true, + Message: "Reserved keyword used", + }, + wantLen: 3, // "user" table + "id", "select" columns + checkPasses: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := validateReservedKeywords(db, tt.rule, "test_rule") + if len(results) != tt.wantLen { + t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen) + } + }) + } +} + +func TestValidateMissingPrimaryKey(t *testing.T) { + // Create database with and without primary keys + db := &models.Database{ + Name: "testdb", + Schemas: []*models.Schema{ + { + Name: "public", + Tables: []*models.Table{ + { + Name: "with_pk", + Columns: map[string]*models.Column{ + "id": { + Name: "id", + Type: "bigint", + IsPrimaryKey: true, + }, + }, + }, + { + Name: "without_pk", + Columns: map[string]*models.Column{ + "name": { + Name: "name", + Type: "varchar(50)", + }, + }, + }, + }, + }, + }, + } + + rule := Rule{ + Message: "Table missing primary key", + } + + results := validateMissingPrimaryKey(db, rule, "test_rule") + + if len(results) != 2 { + t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results)) + } + + // First result should pass (with_pk has PK) + if results[0].Passed != true { + t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed) + } + + // Second result should fail (without_pk missing PK) + if results[1].Passed != false { + t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed) + } +} + +func TestValidateOrphanedForeignKey(t *testing.T) { + // Create database with orphaned FK + db := &models.Database{ + Name: "testdb", + Schemas: []*models.Schema{ + { + Name: "public", + Tables: []*models.Table{ + { + Name: "users", + Columns: map[string]*models.Column{ + "id": { + Name: "id", + Type: "bigint", + IsPrimaryKey: true, + }, + }, + Constraints: map[string]*models.Constraint{ + "fk_nonexistent": { + Name: "fk_nonexistent", + Type: models.ForeignKeyConstraint, + Columns: []string{"rid_organization"}, + ReferencedTable: "nonexistent_table", + ReferencedSchema: "public", + }, + }, + }, + }, + }, + }, + } + + rule := Rule{ + Message: "Foreign key references non-existent table", + } + + results := validateOrphanedForeignKey(db, rule, "test_rule") + + if len(results) != 1 { + t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results)) + } + + if results[0].Passed != false { + t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed) + } +} + +func TestValidateCircularDependency(t *testing.T) { + // Create database with circular dependency + db := &models.Database{ + Name: "testdb", + Schemas: []*models.Schema{ + { + Name: "public", + Tables: []*models.Table{ + { + Name: "table_a", + Columns: map[string]*models.Column{ + "id": {Name: "id", Type: "bigint", IsPrimaryKey: true}, + }, + Constraints: map[string]*models.Constraint{ + "fk_to_b": { + Name: "fk_to_b", + Type: models.ForeignKeyConstraint, + ReferencedTable: "table_b", + ReferencedSchema: "public", + }, + }, + }, + { + Name: "table_b", + Columns: map[string]*models.Column{ + "id": {Name: "id", Type: "bigint", IsPrimaryKey: true}, + }, + Constraints: map[string]*models.Constraint{ + "fk_to_a": { + Name: "fk_to_a", + Type: models.ForeignKeyConstraint, + ReferencedTable: "table_a", + ReferencedSchema: "public", + }, + }, + }, + }, + }, + }, + } + + rule := Rule{ + Message: "Circular dependency detected", + } + + results := validateCircularDependency(db, rule, "test_rule") + + // Should detect circular dependency in both tables + if len(results) == 0 { + t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection") + } + + for _, result := range results { + if result.Passed { + t.Error("validateCircularDependency() passed=true, want false for circular dependency") + } + } +} + +func TestNormalizeDataType(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"varchar(50)", "varchar"}, + {"decimal(10,2)", "decimal"}, + {"int", "int"}, + {"BIGINT", "bigint"}, + {"VARCHAR(255)", "varchar"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := normalizeDataType(tt.input) + if result != tt.expected { + t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestContains(t *testing.T) { + tests := []struct { + name string + slice []string + value string + expected bool + }{ + {"found exact", []string{"foo", "bar", "baz"}, "bar", true}, + {"not found", []string{"foo", "bar", "baz"}, "qux", false}, + {"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true}, + {"empty slice", []string{}, "foo", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := contains(tt.slice, tt.value) + if result != tt.expected { + t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected) + } + }) + } +} + +func TestHasCycle(t *testing.T) { + tests := []struct { + name string + graph map[string][]string + node string + expected bool + }{ + { + name: "simple cycle", + graph: map[string][]string{ + "A": {"B"}, + "B": {"C"}, + "C": {"A"}, + }, + node: "A", + expected: true, + }, + { + name: "no cycle", + graph: map[string][]string{ + "A": {"B"}, + "B": {"C"}, + "C": {}, + }, + node: "A", + expected: false, + }, + { + name: "self cycle", + graph: map[string][]string{ + "A": {"A"}, + }, + node: "A", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + visited := make(map[string]bool) + recStack := make(map[string]bool) + result := hasCycle(tt.node, tt.graph, visited, recStack) + if result != tt.expected { + t.Errorf("hasCycle() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestFormatLocation(t *testing.T) { + tests := []struct { + schema string + table string + column string + expected string + }{ + {"public", "users", "id", "public.users.id"}, + {"public", "users", "", "public.users"}, + {"public", "", "", "public"}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + result := formatLocation(tt.schema, tt.table, tt.column) + if result != tt.expected { + t.Errorf("formatLocation(%q, %q, %q) = %q, want %q", + tt.schema, tt.table, tt.column, result, tt.expected) + } + }) + } +} diff --git a/pkg/pgsql/datatypes_test.go b/pkg/pgsql/datatypes_test.go new file mode 100644 index 0000000..3ab8782 --- /dev/null +++ b/pkg/pgsql/datatypes_test.go @@ -0,0 +1,339 @@ +package pgsql + +import ( + "testing" +) + +func TestValidSQLType(t *testing.T) { + tests := []struct { + name string + sqltype string + want bool + }{ + // PostgreSQL types + {"Valid PGSQL bigint", "bigint", true}, + {"Valid PGSQL integer", "integer", true}, + {"Valid PGSQL text", "text", true}, + {"Valid PGSQL boolean", "boolean", true}, + {"Valid PGSQL double precision", "double precision", true}, + {"Valid PGSQL bytea", "bytea", true}, + {"Valid PGSQL uuid", "uuid", true}, + {"Valid PGSQL jsonb", "jsonb", true}, + {"Valid PGSQL json", "json", true}, + {"Valid PGSQL timestamp", "timestamp", true}, + {"Valid PGSQL date", "date", true}, + {"Valid PGSQL time", "time", true}, + {"Valid PGSQL citext", "citext", true}, + + // Standard types + {"Valid std double", "double", true}, + {"Valid std blob", "blob", true}, + + // Case insensitive + {"Case insensitive BIGINT", "BIGINT", true}, + {"Case insensitive TeXt", "TeXt", true}, + {"Case insensitive BoOlEaN", "BoOlEaN", true}, + + // Invalid types + {"Invalid type", "invalidtype", false}, + {"Invalid type varchar", "varchar", false}, + {"Empty string", "", false}, + {"Random string", "foobar", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ValidSQLType(tt.sqltype) + if got != tt.want { + t.Errorf("ValidSQLType(%q) = %v, want %v", tt.sqltype, got, tt.want) + } + }) + } +} + +func TestGetSQLType(t *testing.T) { + tests := []struct { + name string + anytype string + want string + }{ + // Go types to PostgreSQL types + {"Go bool to boolean", "bool", "boolean"}, + {"Go int64 to bigint", "int64", "bigint"}, + {"Go int to integer", "int", "integer"}, + {"Go string to text", "string", "text"}, + {"Go float64 to double precision", "float64", "double precision"}, + {"Go float32 to real", "float32", "real"}, + {"Go []byte to bytea", "[]byte", "bytea"}, + + // SQL types remain SQL types + {"SQL bigint", "bigint", "bigint"}, + {"SQL integer", "integer", "integer"}, + {"SQL text", "text", "text"}, + {"SQL boolean", "boolean", "boolean"}, + {"SQL uuid", "uuid", "uuid"}, + {"SQL jsonb", "jsonb", "jsonb"}, + + // Case insensitive Go types + {"Case insensitive BOOL", "BOOL", "boolean"}, + {"Case insensitive InT64", "InT64", "bigint"}, + {"Case insensitive STRING", "STRING", "text"}, + + // Case insensitive SQL types + {"Case insensitive BIGINT", "BIGINT", "bigint"}, + {"Case insensitive TEXT", "TEXT", "text"}, + + // Custom types + {"Custom sqluuid", "sqluuid", "uuid"}, + {"Custom sqljsonb", "sqljsonb", "jsonb"}, + {"Custom sqlint64", "sqlint64", "bigint"}, + + // Unknown types default to text + {"Unknown type varchar", "varchar", "text"}, + {"Unknown type foobar", "foobar", "text"}, + {"Empty string", "", "text"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetSQLType(tt.anytype) + if got != tt.want { + t.Errorf("GetSQLType(%q) = %q, want %q", tt.anytype, got, tt.want) + } + }) + } +} + +func TestConvertSQLType(t *testing.T) { + tests := []struct { + name string + anytype string + want string + }{ + // Go types to PostgreSQL types + {"Go bool to boolean", "bool", "boolean"}, + {"Go int64 to bigint", "int64", "bigint"}, + {"Go int to integer", "int", "integer"}, + {"Go string to text", "string", "text"}, + {"Go float64 to double precision", "float64", "double precision"}, + {"Go float32 to real", "float32", "real"}, + {"Go []byte to bytea", "[]byte", "bytea"}, + + // SQL types remain SQL types + {"SQL bigint", "bigint", "bigint"}, + {"SQL integer", "integer", "integer"}, + {"SQL text", "text", "text"}, + {"SQL boolean", "boolean", "boolean"}, + + // Case insensitive + {"Case insensitive BOOL", "BOOL", "boolean"}, + {"Case insensitive InT64", "InT64", "bigint"}, + + // Unknown types remain unchanged (difference from GetSQLType) + {"Unknown type varchar", "varchar", "varchar"}, + {"Unknown type foobar", "foobar", "foobar"}, + {"Empty string", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ConvertSQLType(tt.anytype) + if got != tt.want { + t.Errorf("ConvertSQLType(%q) = %q, want %q", tt.anytype, got, tt.want) + } + }) + } +} + +func TestIsGoType(t *testing.T) { + tests := []struct { + name string + typeName string + want bool + }{ + // Go basic types + {"Go bool", "bool", true}, + {"Go int64", "int64", true}, + {"Go int", "int", true}, + {"Go int32", "int32", true}, + {"Go int16", "int16", true}, + {"Go int8", "int8", true}, + {"Go uint", "uint", true}, + {"Go uint64", "uint64", true}, + {"Go uint32", "uint32", true}, + {"Go uint16", "uint16", true}, + {"Go uint8", "uint8", true}, + {"Go float64", "float64", true}, + {"Go float32", "float32", true}, + {"Go string", "string", true}, + {"Go []byte", "[]byte", true}, + + // Go custom types + {"Go complex64", "complex64", true}, + {"Go complex128", "complex128", true}, + {"Go uintptr", "uintptr", true}, + {"Go Pointer", "Pointer", true}, + + // Custom SQL types + {"Custom sqluuid", "sqluuid", true}, + {"Custom sqljsonb", "sqljsonb", true}, + {"Custom sqlint64", "sqlint64", true}, + {"Custom customdate", "customdate", true}, + {"Custom customtime", "customtime", true}, + + // Case insensitive + {"Case insensitive BOOL", "BOOL", true}, + {"Case insensitive InT64", "InT64", true}, + {"Case insensitive STRING", "STRING", true}, + + // SQL types (not Go types) + {"SQL bigint", "bigint", false}, + {"SQL integer", "integer", false}, + {"SQL text", "text", false}, + {"SQL boolean", "boolean", false}, + + // Invalid types + {"Invalid type", "invalidtype", false}, + {"Empty string", "", false}, + {"Random string", "foobar", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsGoType(tt.typeName) + if got != tt.want { + t.Errorf("IsGoType(%q) = %v, want %v", tt.typeName, got, tt.want) + } + }) + } +} + +func TestGetStdTypeFromGo(t *testing.T) { + tests := []struct { + name string + typeName string + want string + }{ + // Go types to standard SQL types + {"Go bool to boolean", "bool", "boolean"}, + {"Go int64 to bigint", "int64", "bigint"}, + {"Go int to integer", "int", "integer"}, + {"Go string to text", "string", "text"}, + {"Go float64 to double", "float64", "double"}, + {"Go float32 to double", "float32", "double"}, + {"Go []byte to blob", "[]byte", "blob"}, + {"Go int32 to integer", "int32", "integer"}, + {"Go int16 to smallint", "int16", "smallint"}, + + // Custom types + {"Custom sqluuid to uuid", "sqluuid", "uuid"}, + {"Custom sqljsonb to jsonb", "sqljsonb", "jsonb"}, + {"Custom sqlint64 to bigint", "sqlint64", "bigint"}, + {"Custom customdate to date", "customdate", "date"}, + + // Case insensitive + {"Case insensitive BOOL", "BOOL", "boolean"}, + {"Case insensitive InT64", "InT64", "bigint"}, + {"Case insensitive STRING", "STRING", "text"}, + + // Non-Go types remain unchanged + {"SQL bigint unchanged", "bigint", "bigint"}, + {"SQL integer unchanged", "integer", "integer"}, + {"Invalid type unchanged", "invalidtype", "invalidtype"}, + {"Empty string unchanged", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetStdTypeFromGo(tt.typeName) + if got != tt.want { + t.Errorf("GetStdTypeFromGo(%q) = %q, want %q", tt.typeName, got, tt.want) + } + }) + } +} + +func TestGoToStdTypesMap(t *testing.T) { + // Test that the map contains expected entries + expectedMappings := map[string]string{ + "bool": "boolean", + "int64": "bigint", + "int": "integer", + "string": "text", + "float64": "double", + "[]byte": "blob", + } + + for goType, expectedStd := range expectedMappings { + if stdType, ok := GoToStdTypes[goType]; !ok { + t.Errorf("GoToStdTypes missing entry for %q", goType) + } else if stdType != expectedStd { + t.Errorf("GoToStdTypes[%q] = %q, want %q", goType, stdType, expectedStd) + } + } + + // Test that the map is not empty + if len(GoToStdTypes) == 0 { + t.Error("GoToStdTypes map is empty") + } +} + +func TestGoToPGSQLTypesMap(t *testing.T) { + // Test that the map contains expected entries + expectedMappings := map[string]string{ + "bool": "boolean", + "int64": "bigint", + "int": "integer", + "string": "text", + "float64": "double precision", + "float32": "real", + "[]byte": "bytea", + } + + for goType, expectedPG := range expectedMappings { + if pgType, ok := GoToPGSQLTypes[goType]; !ok { + t.Errorf("GoToPGSQLTypes missing entry for %q", goType) + } else if pgType != expectedPG { + t.Errorf("GoToPGSQLTypes[%q] = %q, want %q", goType, pgType, expectedPG) + } + } + + // Test that the map is not empty + if len(GoToPGSQLTypes) == 0 { + t.Error("GoToPGSQLTypes map is empty") + } +} + +func TestTypeConversionConsistency(t *testing.T) { + // Test that GetSQLType and ConvertSQLType are consistent for known types + knownGoTypes := []string{"bool", "int64", "int", "string", "float64", "[]byte"} + + for _, goType := range knownGoTypes { + getSQLResult := GetSQLType(goType) + convertResult := ConvertSQLType(goType) + + if getSQLResult != convertResult { + t.Errorf("Inconsistent results for %q: GetSQLType=%q, ConvertSQLType=%q", + goType, getSQLResult, convertResult) + } + } +} + +func TestGetSQLTypeVsConvertSQLTypeDifference(t *testing.T) { + // Test that GetSQLType returns "text" for unknown types + // while ConvertSQLType returns the input unchanged + unknownTypes := []string{"varchar", "char", "customtype", "unknowntype"} + + for _, unknown := range unknownTypes { + getSQLResult := GetSQLType(unknown) + convertResult := ConvertSQLType(unknown) + + if getSQLResult != "text" { + t.Errorf("GetSQLType(%q) = %q, want %q", unknown, getSQLResult, "text") + } + + if convertResult != unknown { + t.Errorf("ConvertSQLType(%q) = %q, want %q", unknown, convertResult, unknown) + } + } +} diff --git a/pkg/pgsql/keywords_test.go b/pkg/pgsql/keywords_test.go new file mode 100644 index 0000000..93ebcd7 --- /dev/null +++ b/pkg/pgsql/keywords_test.go @@ -0,0 +1,136 @@ +package pgsql + +import ( + "testing" +) + +func TestGetPostgresKeywords(t *testing.T) { + keywords := GetPostgresKeywords() + + // Test that keywords are returned + if len(keywords) == 0 { + t.Fatal("Expected non-empty list of keywords") + } + + // Test that we get all keywords from the map + expectedCount := len(postgresKeywords) + if len(keywords) != expectedCount { + t.Errorf("Expected %d keywords, got %d", expectedCount, len(keywords)) + } + + // Test that all returned keywords exist in the map + for _, keyword := range keywords { + if !postgresKeywords[keyword] { + t.Errorf("Keyword %q not found in postgresKeywords map", keyword) + } + } + + // Test that no duplicate keywords are returned + seen := make(map[string]bool) + for _, keyword := range keywords { + if seen[keyword] { + t.Errorf("Duplicate keyword found: %q", keyword) + } + seen[keyword] = true + } +} + +func TestPostgresKeywordsMap(t *testing.T) { + tests := []struct { + name string + keyword string + want bool + }{ + {"SELECT keyword", "select", true}, + {"FROM keyword", "from", true}, + {"WHERE keyword", "where", true}, + {"TABLE keyword", "table", true}, + {"PRIMARY keyword", "primary", true}, + {"FOREIGN keyword", "foreign", true}, + {"CREATE keyword", "create", true}, + {"DROP keyword", "drop", true}, + {"ALTER keyword", "alter", true}, + {"INDEX keyword", "index", true}, + {"NOT keyword", "not", true}, + {"NULL keyword", "null", true}, + {"TRUE keyword", "true", true}, + {"FALSE keyword", "false", true}, + {"Non-keyword lowercase", "notakeyword", false}, + {"Non-keyword uppercase", "NOTAKEYWORD", false}, + {"Empty string", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := postgresKeywords[tt.keyword] + if got != tt.want { + t.Errorf("postgresKeywords[%q] = %v, want %v", tt.keyword, got, tt.want) + } + }) + } +} + +func TestPostgresKeywordsMapContent(t *testing.T) { + // Test that the map contains expected common keywords + commonKeywords := []string{ + "select", "insert", "update", "delete", "create", "drop", "alter", + "table", "index", "view", "schema", "function", "procedure", + "primary", "foreign", "key", "constraint", "unique", "check", + "null", "not", "and", "or", "like", "in", "between", + "join", "inner", "left", "right", "cross", "full", "outer", + "where", "having", "group", "order", "limit", "offset", + "union", "intersect", "except", + "begin", "commit", "rollback", "transaction", + } + + for _, keyword := range commonKeywords { + if !postgresKeywords[keyword] { + t.Errorf("Expected common keyword %q to be in postgresKeywords map", keyword) + } + } +} + +func TestPostgresKeywordsMapSize(t *testing.T) { + // PostgreSQL has a substantial list of reserved keywords + // This test ensures the map has a reasonable number of entries + minExpectedKeywords := 200 // PostgreSQL 13+ has 400+ reserved words + + if len(postgresKeywords) < minExpectedKeywords { + t.Errorf("Expected at least %d keywords, got %d. The map may be incomplete.", + minExpectedKeywords, len(postgresKeywords)) + } +} + +func TestGetPostgresKeywordsConsistency(t *testing.T) { + // Test that calling GetPostgresKeywords multiple times returns consistent results + keywords1 := GetPostgresKeywords() + keywords2 := GetPostgresKeywords() + + if len(keywords1) != len(keywords2) { + t.Errorf("Inconsistent results: first call returned %d keywords, second call returned %d", + len(keywords1), len(keywords2)) + } + + // Create a map from both results to compare + map1 := make(map[string]bool) + map2 := make(map[string]bool) + + for _, k := range keywords1 { + map1[k] = true + } + for _, k := range keywords2 { + map2[k] = true + } + + // Check that both contain the same keywords + for k := range map1 { + if !map2[k] { + t.Errorf("Keyword %q present in first call but not in second", k) + } + } + for k := range map2 { + if !map1[k] { + t.Errorf("Keyword %q present in second call but not in first", k) + } + } +} diff --git a/pkg/reflectutil/helpers_test.go b/pkg/reflectutil/helpers_test.go new file mode 100644 index 0000000..2e8b28e --- /dev/null +++ b/pkg/reflectutil/helpers_test.go @@ -0,0 +1,490 @@ +package reflectutil + +import ( + "reflect" + "testing" +) + +type testStruct struct { + Name string + Age int + Active bool + Nested *nestedStruct + Private string +} + +type nestedStruct struct { + Value string + Count int +} + +func TestDeref(t *testing.T) { + tests := []struct { + name string + input interface{} + wantValid bool + wantKind reflect.Kind + }{ + { + name: "non-pointer int", + input: 42, + wantValid: true, + wantKind: reflect.Int, + }, + { + name: "single pointer", + input: ptrInt(42), + wantValid: true, + wantKind: reflect.Int, + }, + { + name: "double pointer", + input: ptrPtr(ptrInt(42)), + wantValid: true, + wantKind: reflect.Int, + }, + { + name: "nil pointer", + input: (*int)(nil), + wantValid: false, + wantKind: reflect.Ptr, + }, + { + name: "string", + input: "test", + wantValid: true, + wantKind: reflect.String, + }, + { + name: "struct", + input: testStruct{Name: "test"}, + wantValid: true, + wantKind: reflect.Struct, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := reflect.ValueOf(tt.input) + got, valid := Deref(v) + + if valid != tt.wantValid { + t.Errorf("Deref() valid = %v, want %v", valid, tt.wantValid) + } + + if got.Kind() != tt.wantKind { + t.Errorf("Deref() kind = %v, want %v", got.Kind(), tt.wantKind) + } + }) + } +} + +func TestDerefInterface(t *testing.T) { + i := 42 + pi := &i + ppi := &pi + + tests := []struct { + name string + input interface{} + wantKind reflect.Kind + }{ + {"int", 42, reflect.Int}, + {"pointer to int", &i, reflect.Int}, + {"double pointer to int", ppi, reflect.Int}, + {"string", "test", reflect.String}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DerefInterface(tt.input) + if got.Kind() != tt.wantKind { + t.Errorf("DerefInterface() kind = %v, want %v", got.Kind(), tt.wantKind) + } + }) + } +} + +func TestGetFieldValue(t *testing.T) { + ts := testStruct{ + Name: "John", + Age: 30, + Active: true, + Nested: &nestedStruct{Value: "nested", Count: 5}, + } + + tests := []struct { + name string + item interface{} + field string + want interface{} + }{ + {"struct field Name", ts, "Name", "John"}, + {"struct field Age", ts, "Age", 30}, + {"struct field Active", ts, "Active", true}, + {"struct non-existent field", ts, "NonExistent", nil}, + {"pointer to struct", &ts, "Name", "John"}, + {"map string key", map[string]string{"key": "value"}, "key", "value"}, + {"map int key", map[string]int{"count": 42}, "count", 42}, + {"map non-existent key", map[string]string{"key": "value"}, "missing", nil}, + {"nil pointer", (*testStruct)(nil), "Name", nil}, + {"non-struct non-map", 42, "field", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetFieldValue(tt.item, tt.field) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetFieldValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsSliceOrArray(t *testing.T) { + arr := [3]int{1, 2, 3} + + tests := []struct { + name string + input interface{} + want bool + }{ + {"slice", []int{1, 2, 3}, true}, + {"array", arr, true}, + {"pointer to slice", &[]int{1, 2, 3}, true}, + {"string", "test", false}, + {"int", 42, false}, + {"map", map[string]int{}, false}, + {"nil slice", ([]int)(nil), true}, // nil slice is still Kind==Slice + {"nil pointer", (*[]int)(nil), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsSliceOrArray(tt.input) + if got != tt.want { + t.Errorf("IsSliceOrArray() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsMap(t *testing.T) { + tests := []struct { + name string + input interface{} + want bool + }{ + {"map[string]int", map[string]int{"a": 1}, true}, + {"map[int]string", map[int]string{1: "a"}, true}, + {"pointer to map", &map[string]int{"a": 1}, true}, + {"slice", []int{1, 2, 3}, false}, + {"string", "test", false}, + {"int", 42, false}, + {"nil map", (map[string]int)(nil), true}, // nil map is still Kind==Map + {"nil pointer", (*map[string]int)(nil), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsMap(tt.input) + if got != tt.want { + t.Errorf("IsMap() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSliceLen(t *testing.T) { + arr := [3]int{1, 2, 3} + + tests := []struct { + name string + input interface{} + want int + }{ + {"slice length 3", []int{1, 2, 3}, 3}, + {"empty slice", []int{}, 0}, + {"array length 3", arr, 3}, + {"pointer to slice", &[]int{1, 2, 3}, 3}, + {"not a slice", "test", 0}, + {"int", 42, 0}, + {"nil slice", ([]int)(nil), 0}, + {"nil pointer", (*[]int)(nil), 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SliceLen(tt.input) + if got != tt.want { + t.Errorf("SliceLen() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMapLen(t *testing.T) { + tests := []struct { + name string + input interface{} + want int + }{ + {"map length 2", map[string]int{"a": 1, "b": 2}, 2}, + {"empty map", map[string]int{}, 0}, + {"pointer to map", &map[string]int{"a": 1}, 1}, + {"not a map", []int{1, 2, 3}, 0}, + {"string", "test", 0}, + {"nil map", (map[string]int)(nil), 0}, + {"nil pointer", (*map[string]int)(nil), 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MapLen(tt.input) + if got != tt.want { + t.Errorf("MapLen() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSliceToInterfaces(t *testing.T) { + tests := []struct { + name string + input interface{} + want []interface{} + }{ + {"int slice", []int{1, 2, 3}, []interface{}{1, 2, 3}}, + {"string slice", []string{"a", "b"}, []interface{}{"a", "b"}}, + {"empty slice", []int{}, []interface{}{}}, + {"pointer to slice", &[]int{1, 2}, []interface{}{1, 2}}, + {"not a slice", "test", []interface{}{}}, + {"nil slice", ([]int)(nil), []interface{}{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SliceToInterfaces(tt.input) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SliceToInterfaces() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMapKeys(t *testing.T) { + tests := []struct { + name string + input interface{} + want []interface{} + }{ + {"map with keys", map[string]int{"a": 1, "b": 2}, []interface{}{"a", "b"}}, + {"empty map", map[string]int{}, []interface{}{}}, + {"not a map", []int{1, 2, 3}, []interface{}{}}, + {"nil map", (map[string]int)(nil), []interface{}{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MapKeys(tt.input) + if len(got) != len(tt.want) { + t.Errorf("MapKeys() length = %v, want %v", len(got), len(tt.want)) + } + // For maps, order is not guaranteed, so just check length + }) + } +} + +func TestMapValues(t *testing.T) { + tests := []struct { + name string + input interface{} + want int // length of values + }{ + {"map with values", map[string]int{"a": 1, "b": 2}, 2}, + {"empty map", map[string]int{}, 0}, + {"not a map", []int{1, 2, 3}, 0}, + {"nil map", (map[string]int)(nil), 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MapValues(tt.input) + if len(got) != tt.want { + t.Errorf("MapValues() length = %v, want %v", len(got), tt.want) + } + }) + } +} + +func TestMapGet(t *testing.T) { + m := map[string]int{"a": 1, "b": 2} + + tests := []struct { + name string + input interface{} + key interface{} + want interface{} + }{ + {"existing key", m, "a", 1}, + {"existing key b", m, "b", 2}, + {"non-existing key", m, "c", nil}, + {"pointer to map", &m, "a", 1}, + {"not a map", []int{1, 2}, 0, nil}, + {"nil map", (map[string]int)(nil), "a", nil}, + {"nil pointer", (*map[string]int)(nil), "a", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MapGet(tt.input, tt.key) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MapGet() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSliceIndex(t *testing.T) { + s := []int{10, 20, 30} + + tests := []struct { + name string + slice interface{} + index int + want interface{} + }{ + {"index 0", s, 0, 10}, + {"index 1", s, 1, 20}, + {"index 2", s, 2, 30}, + {"negative index", s, -1, nil}, + {"out of bounds", s, 5, nil}, + {"pointer to slice", &s, 1, 20}, + {"not a slice", "test", 0, nil}, + {"nil slice", ([]int)(nil), 0, nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SliceIndex(tt.slice, tt.index) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SliceIndex() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCompareValues(t *testing.T) { + tests := []struct { + name string + a interface{} + b interface{} + want int + }{ + {"both nil", nil, nil, 0}, + {"a nil", nil, 5, -1}, + {"b nil", 5, nil, 1}, + {"equal strings", "abc", "abc", 0}, + {"a less than b strings", "abc", "xyz", -1}, + {"a greater than b strings", "xyz", "abc", 1}, + {"equal ints", 5, 5, 0}, + {"a less than b ints", 3, 7, -1}, + {"a greater than b ints", 10, 5, 1}, + {"equal floats", 3.14, 3.14, 0}, + {"a less than b floats", 2.5, 5.5, -1}, + {"a greater than b floats", 10.5, 5.5, 1}, + {"equal uints", uint(5), uint(5), 0}, + {"different types", "abc", 123, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CompareValues(tt.a, tt.b) + if got != tt.want { + t.Errorf("CompareValues(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want) + } + }) + } +} + +func TestGetNestedValue(t *testing.T) { + nested := map[string]interface{}{ + "level1": map[string]interface{}{ + "level2": map[string]interface{}{ + "value": "deep", + }, + }, + } + + ts := testStruct{ + Name: "John", + Nested: &nestedStruct{ + Value: "nested value", + Count: 42, + }, + } + + tests := []struct { + name string + input interface{} + path string + want interface{} + }{ + {"empty path", nested, "", nested}, + {"single level map", nested, "level1", nested["level1"]}, + {"nested map", nested, "level1.level2", map[string]interface{}{"value": "deep"}}, + {"deep nested map", nested, "level1.level2.value", "deep"}, + {"struct field", ts, "Name", "John"}, + {"nested struct field", ts, "Nested", ts.Nested}, + {"non-existent path", nested, "missing.path", nil}, + {"nil input", nil, "path", nil}, + {"partial missing path", nested, "level1.missing", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetNestedValue(tt.input, tt.path) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetNestedValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDeepEqual(t *testing.T) { + tests := []struct { + name string + a interface{} + b interface{} + want bool + }{ + {"equal ints", 42, 42, true}, + {"different ints", 42, 43, false}, + {"equal strings", "test", "test", true}, + {"different strings", "test", "other", false}, + {"equal slices", []int{1, 2, 3}, []int{1, 2, 3}, true}, + {"different slices", []int{1, 2, 3}, []int{1, 2, 4}, false}, + {"equal maps", map[string]int{"a": 1}, map[string]int{"a": 1}, true}, + {"different maps", map[string]int{"a": 1}, map[string]int{"a": 2}, false}, + {"both nil", nil, nil, true}, + {"one nil", nil, 42, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DeepEqual(tt.a, tt.b) + if got != tt.want { + t.Errorf("DeepEqual(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want) + } + }) + } +} + +// Helper functions +func ptrInt(i int) *int { + return &i +} + +func ptrPtr(p *int) **int { + return &p +}