test(pgsql, reflectutil): ✨ add comprehensive test coverage
All checks were successful
All checks were successful
* Introduce tests for PostgreSQL data types and keywords. * Implement tests for reflect utility functions. * Ensure consistency and correctness of type conversions and keyword mappings. * Validate behavior for various edge cases and input types.
This commit is contained in:
238
pkg/inspector/inspector_test.go
Normal file
238
pkg/inspector/inspector_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewInspector(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
|
||||
if inspector == nil {
|
||||
t.Fatal("NewInspector() returned nil")
|
||||
}
|
||||
|
||||
if inspector.db != db {
|
||||
t.Error("NewInspector() database not set correctly")
|
||||
}
|
||||
|
||||
if inspector.config != config {
|
||||
t.Error("NewInspector() config not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspect(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() returned error: %v", err)
|
||||
}
|
||||
|
||||
if report == nil {
|
||||
t.Fatal("Inspect() returned nil report")
|
||||
}
|
||||
|
||||
if report.Database != db.Name {
|
||||
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
|
||||
}
|
||||
|
||||
if report.Summary.TotalRules != len(config.Rules) {
|
||||
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
|
||||
}
|
||||
|
||||
if len(report.Violations) == 0 {
|
||||
t.Error("Inspect() returned no violations, expected some results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectWithDisabledRules(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Disable all rules
|
||||
for name := range config.Rules {
|
||||
rule := config.Rules[name]
|
||||
rule.Enabled = "off"
|
||||
config.Rules[name] = rule
|
||||
}
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
|
||||
}
|
||||
|
||||
if report.Summary.RulesChecked != 0 {
|
||||
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
|
||||
}
|
||||
|
||||
if report.Summary.RulesSkipped != len(config.Rules) {
|
||||
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectWithEnforcedRules(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Enable only one rule and enforce it
|
||||
for name := range config.Rules {
|
||||
rule := config.Rules[name]
|
||||
rule.Enabled = "off"
|
||||
config.Rules[name] = rule
|
||||
}
|
||||
|
||||
primaryKeyRule := config.Rules["primary_key_naming"]
|
||||
primaryKeyRule.Enabled = "enforce"
|
||||
primaryKeyRule.Pattern = "^id$"
|
||||
config.Rules["primary_key_naming"] = primaryKeyRule
|
||||
|
||||
inspector := NewInspector(db, config)
|
||||
report, err := inspector.Inspect()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Inspect() returned error: %v", err)
|
||||
}
|
||||
|
||||
if report.Summary.RulesChecked != 1 {
|
||||
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
|
||||
}
|
||||
|
||||
// All results should be at error level for enforced rules
|
||||
for _, violation := range report.Violations {
|
||||
if violation.Level != "error" {
|
||||
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSummary(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
config := GetDefaultConfig()
|
||||
inspector := NewInspector(db, config)
|
||||
|
||||
results := []ValidationResult{
|
||||
{RuleName: "rule1", Passed: true, Level: "error"},
|
||||
{RuleName: "rule2", Passed: false, Level: "error"},
|
||||
{RuleName: "rule3", Passed: false, Level: "warning"},
|
||||
{RuleName: "rule4", Passed: true, Level: "warning"},
|
||||
}
|
||||
|
||||
summary := inspector.generateSummary(results)
|
||||
|
||||
if summary.PassedCount != 2 {
|
||||
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
|
||||
}
|
||||
|
||||
if summary.ErrorCount != 1 {
|
||||
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
|
||||
}
|
||||
|
||||
if summary.WarningCount != 1 {
|
||||
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
report *InspectorReport
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "with errors",
|
||||
report: &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
ErrorCount: 5,
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "without errors",
|
||||
report: &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
ErrorCount: 0,
|
||||
WarningCount: 3,
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.report.HasErrors(); got != tt.want {
|
||||
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
functionName string
|
||||
wantExists bool
|
||||
}{
|
||||
{"primary_key_naming", "primary_key_naming", true},
|
||||
{"primary_key_datatype", "primary_key_datatype", true},
|
||||
{"foreign_key_column_naming", "foreign_key_column_naming", true},
|
||||
{"table_regexpr", "table_regexpr", true},
|
||||
{"column_regexpr", "column_regexpr", true},
|
||||
{"reserved_words", "reserved_words", true},
|
||||
{"have_primary_key", "have_primary_key", true},
|
||||
{"orphaned_foreign_key", "orphaned_foreign_key", true},
|
||||
{"circular_dependency", "circular_dependency", true},
|
||||
{"unknown_function", "unknown_function", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, exists := getValidator(tt.functionName)
|
||||
if exists != tt.wantExists {
|
||||
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateResult(t *testing.T) {
|
||||
result := createResult(
|
||||
"test_rule",
|
||||
true,
|
||||
"Test message",
|
||||
"schema.table.column",
|
||||
map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
},
|
||||
)
|
||||
|
||||
if result.RuleName != "test_rule" {
|
||||
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
|
||||
}
|
||||
|
||||
if !result.Passed {
|
||||
t.Error("createResult() Passed = false, want true")
|
||||
}
|
||||
|
||||
if result.Message != "Test message" {
|
||||
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
|
||||
}
|
||||
|
||||
if result.Location != "schema.table.column" {
|
||||
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
|
||||
}
|
||||
|
||||
if len(result.Context) != 2 {
|
||||
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
|
||||
}
|
||||
}
|
||||
366
pkg/inspector/report_test.go
Normal file
366
pkg/inspector/report_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func createTestReport() *InspectorReport {
|
||||
return &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
TotalRules: 10,
|
||||
RulesChecked: 8,
|
||||
RulesSkipped: 2,
|
||||
ErrorCount: 3,
|
||||
WarningCount: 5,
|
||||
PassedCount: 12,
|
||||
},
|
||||
Violations: []ValidationResult{
|
||||
{
|
||||
RuleName: "primary_key_naming",
|
||||
Level: "error",
|
||||
Message: "Primary key should start with 'id_'",
|
||||
Location: "public.users.user_id",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "users",
|
||||
"column": "user_id",
|
||||
"pattern": "^id_",
|
||||
},
|
||||
},
|
||||
{
|
||||
RuleName: "table_name_length",
|
||||
Level: "warning",
|
||||
Message: "Table name too long",
|
||||
Location: "public.very_long_table_name_that_exceeds_limits",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "very_long_table_name_that_exceeds_limits",
|
||||
"length": 44,
|
||||
"max_length": 32,
|
||||
},
|
||||
},
|
||||
},
|
||||
GeneratedAt: time.Now(),
|
||||
Database: "testdb",
|
||||
SourceFormat: "postgresql",
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMarkdownFormatter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
if formatter == nil {
|
||||
t.Fatal("NewMarkdownFormatter() returned nil")
|
||||
}
|
||||
|
||||
// Buffer is not a terminal, so colors should be disabled
|
||||
if formatter.UseColors {
|
||||
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewJSONFormatter(t *testing.T) {
|
||||
formatter := NewJSONFormatter()
|
||||
|
||||
if formatter == nil {
|
||||
t.Fatal("NewJSONFormatter() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_Format(t *testing.T) {
|
||||
report := createTestReport()
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Check that output contains expected sections
|
||||
if !strings.Contains(output, "# RelSpec Inspector Report") {
|
||||
t.Error("Markdown output missing header")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Database:") {
|
||||
t.Error("Markdown output missing database field")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "testdb") {
|
||||
t.Error("Markdown output missing database name")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Summary") {
|
||||
t.Error("Markdown output missing summary section")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Rules Checked: 8") {
|
||||
t.Error("Markdown output missing rules checked count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Errors: 3") {
|
||||
t.Error("Markdown output missing error count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Warnings: 5") {
|
||||
t.Error("Markdown output missing warning count")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Violations") {
|
||||
t.Error("Markdown output missing violations section")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "primary_key_naming") {
|
||||
t.Error("Markdown output missing rule name")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "public.users.user_id") {
|
||||
t.Error("Markdown output missing location")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
|
||||
report := &InspectorReport{
|
||||
Summary: ReportSummary{
|
||||
TotalRules: 10,
|
||||
RulesChecked: 10,
|
||||
RulesSkipped: 0,
|
||||
ErrorCount: 0,
|
||||
WarningCount: 0,
|
||||
PassedCount: 50,
|
||||
},
|
||||
Violations: []ValidationResult{},
|
||||
GeneratedAt: time.Now(),
|
||||
Database: "testdb",
|
||||
SourceFormat: "postgresql",
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "No violations found") {
|
||||
t.Error("Markdown output should indicate no violations")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFormatter_Format(t *testing.T) {
|
||||
report := createTestReport()
|
||||
formatter := NewJSONFormatter()
|
||||
|
||||
output, err := formatter.Format(report)
|
||||
if err != nil {
|
||||
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's valid JSON
|
||||
var decoded InspectorReport
|
||||
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
|
||||
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check key fields
|
||||
if decoded.Database != "testdb" {
|
||||
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
|
||||
}
|
||||
|
||||
if decoded.Summary.ErrorCount != 3 {
|
||||
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
|
||||
}
|
||||
|
||||
if len(decoded.Violations) != 2 {
|
||||
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
formatter := NewMarkdownFormatter(&buf)
|
||||
|
||||
header := formatter.formatHeader("Test Header")
|
||||
|
||||
if !strings.Contains(header, "# Test Header") {
|
||||
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatBold(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
useColors bool
|
||||
text string
|
||||
wantContains string
|
||||
}{
|
||||
{
|
||||
name: "without colors",
|
||||
useColors: false,
|
||||
text: "Bold Text",
|
||||
wantContains: "**Bold Text**",
|
||||
},
|
||||
{
|
||||
name: "with colors",
|
||||
useColors: true,
|
||||
text: "Bold Text",
|
||||
wantContains: "Bold Text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||
result := formatter.formatBold(tt.text)
|
||||
|
||||
if !strings.Contains(result, tt.wantContains) {
|
||||
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_Colorize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
useColors bool
|
||||
text string
|
||||
color string
|
||||
wantColor bool
|
||||
}{
|
||||
{
|
||||
name: "without colors",
|
||||
useColors: false,
|
||||
text: "Test",
|
||||
color: colorRed,
|
||||
wantColor: false,
|
||||
},
|
||||
{
|
||||
name: "with colors",
|
||||
useColors: true,
|
||||
text: "Test",
|
||||
color: colorRed,
|
||||
wantColor: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: tt.useColors}
|
||||
result := formatter.colorize(tt.text, tt.color)
|
||||
|
||||
hasColor := strings.Contains(result, tt.color)
|
||||
if hasColor != tt.wantColor {
|
||||
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
|
||||
}
|
||||
|
||||
if !strings.Contains(result, tt.text) {
|
||||
t.Errorf("colorize() doesn't contain original text %q", tt.text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatContext(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: false}
|
||||
|
||||
context := map[string]interface{}{
|
||||
"schema": "public",
|
||||
"table": "users",
|
||||
"column": "id",
|
||||
"pattern": "^id_",
|
||||
"max_length": 64,
|
||||
}
|
||||
|
||||
result := formatter.formatContext(context)
|
||||
|
||||
// Should not include schema, table, column (they're in location)
|
||||
if strings.Contains(result, "schema") {
|
||||
t.Error("formatContext() should skip schema field")
|
||||
}
|
||||
|
||||
if strings.Contains(result, "table=") {
|
||||
t.Error("formatContext() should skip table field")
|
||||
}
|
||||
|
||||
if strings.Contains(result, "column=") {
|
||||
t.Error("formatContext() should skip column field")
|
||||
}
|
||||
|
||||
// Should include other fields
|
||||
if !strings.Contains(result, "pattern") {
|
||||
t.Error("formatContext() should include pattern field")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "max_length") {
|
||||
t.Error("formatContext() should include max_length field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
|
||||
formatter := &MarkdownFormatter{UseColors: false}
|
||||
|
||||
violation := ValidationResult{
|
||||
RuleName: "test_rule",
|
||||
Level: "error",
|
||||
Message: "Test violation message",
|
||||
Location: "public.users.id",
|
||||
Passed: false,
|
||||
Context: map[string]interface{}{
|
||||
"pattern": "^id_",
|
||||
},
|
||||
}
|
||||
|
||||
result := formatter.formatViolation(violation, colorRed)
|
||||
|
||||
if !strings.Contains(result, "test_rule") {
|
||||
t.Error("formatViolation() should include rule name")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Test violation message") {
|
||||
t.Error("formatViolation() should include message")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "public.users.id") {
|
||||
t.Error("formatViolation() should include location")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Location:") {
|
||||
t.Error("formatViolation() should include Location label")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Message:") {
|
||||
t.Error("formatViolation() should include Message label")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportFormatConstants(t *testing.T) {
|
||||
// Test that color constants are defined
|
||||
if colorReset == "" {
|
||||
t.Error("colorReset is not defined")
|
||||
}
|
||||
|
||||
if colorRed == "" {
|
||||
t.Error("colorRed is not defined")
|
||||
}
|
||||
|
||||
if colorYellow == "" {
|
||||
t.Error("colorYellow is not defined")
|
||||
}
|
||||
|
||||
if colorGreen == "" {
|
||||
t.Error("colorGreen is not defined")
|
||||
}
|
||||
|
||||
if colorBold == "" {
|
||||
t.Error("colorBold is not defined")
|
||||
}
|
||||
}
|
||||
249
pkg/inspector/rules_test.go
Normal file
249
pkg/inspector/rules_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDefaultConfig(t *testing.T) {
|
||||
config := GetDefaultConfig()
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("GetDefaultConfig() returned nil")
|
||||
}
|
||||
|
||||
if config.Version != "1.0" {
|
||||
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
|
||||
}
|
||||
|
||||
if len(config.Rules) == 0 {
|
||||
t.Error("GetDefaultConfig() returned no rules")
|
||||
}
|
||||
|
||||
// Check that all expected rules are present
|
||||
expectedRules := []string{
|
||||
"primary_key_naming",
|
||||
"primary_key_datatype",
|
||||
"primary_key_auto_increment",
|
||||
"foreign_key_column_naming",
|
||||
"foreign_key_constraint_naming",
|
||||
"foreign_key_index",
|
||||
"table_naming_case",
|
||||
"column_naming_case",
|
||||
"table_name_length",
|
||||
"column_name_length",
|
||||
"reserved_keywords",
|
||||
"missing_primary_key",
|
||||
"orphaned_foreign_key",
|
||||
"circular_dependency",
|
||||
}
|
||||
|
||||
for _, ruleName := range expectedRules {
|
||||
if _, exists := config.Rules[ruleName]; !exists {
|
||||
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_NonExistentFile(t *testing.T) {
|
||||
// Try to load a non-existent file
|
||||
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
|
||||
}
|
||||
|
||||
// Should return default config
|
||||
if config == nil {
|
||||
t.Fatal("LoadConfig() returned nil config for non-existent file")
|
||||
}
|
||||
|
||||
if len(config.Rules) == 0 {
|
||||
t.Error("LoadConfig() returned config with no rules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ValidFile(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "test-config.yaml")
|
||||
|
||||
configContent := `version: "1.0"
|
||||
rules:
|
||||
primary_key_naming:
|
||||
enabled: "enforce"
|
||||
function: "primary_key_naming"
|
||||
pattern: "^pk_"
|
||||
message: "Primary keys must start with pk_"
|
||||
table_name_length:
|
||||
enabled: "warn"
|
||||
function: "table_name_length"
|
||||
max_length: 50
|
||||
message: "Table name too long"
|
||||
`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
config, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() returned error: %v", err)
|
||||
}
|
||||
|
||||
if config.Version != "1.0" {
|
||||
t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
|
||||
}
|
||||
|
||||
if len(config.Rules) != 2 {
|
||||
t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
|
||||
}
|
||||
|
||||
// Check primary_key_naming rule
|
||||
pkRule, exists := config.Rules["primary_key_naming"]
|
||||
if !exists {
|
||||
t.Fatal("LoadConfig() missing primary_key_naming rule")
|
||||
}
|
||||
|
||||
if pkRule.Enabled != "enforce" {
|
||||
t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
|
||||
}
|
||||
|
||||
if pkRule.Pattern != "^pk_" {
|
||||
t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
|
||||
}
|
||||
|
||||
// Check table_name_length rule
|
||||
lengthRule, exists := config.Rules["table_name_length"]
|
||||
if !exists {
|
||||
t.Fatal("LoadConfig() missing table_name_length rule")
|
||||
}
|
||||
|
||||
if lengthRule.MaxLength != 50 {
|
||||
t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_InvalidYAML(t *testing.T) {
|
||||
// Create a temporary invalid config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
|
||||
|
||||
invalidContent := `invalid: yaml: content: {[}]`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = LoadConfig(configPath)
|
||||
if err == nil {
|
||||
t.Error("LoadConfig() with invalid YAML did not return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleIsEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "enforce is enabled",
|
||||
rule: Rule{Enabled: "enforce"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "warn is enabled",
|
||||
rule: Rule{Enabled: "warn"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "off is not enabled",
|
||||
rule: Rule{Enabled: "off"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty is not enabled",
|
||||
rule: Rule{Enabled: ""},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.rule.IsEnabled(); got != tt.want {
|
||||
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleIsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "enforce is enforced",
|
||||
rule: Rule{Enabled: "enforce"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "warn is not enforced",
|
||||
rule: Rule{Enabled: "warn"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "off is not enforced",
|
||||
rule: Rule{Enabled: "off"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.rule.IsEnforced(); got != tt.want {
|
||||
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfigRuleSettings(t *testing.T) {
|
||||
config := GetDefaultConfig()
|
||||
|
||||
// Test specific rule settings
|
||||
pkNamingRule := config.Rules["primary_key_naming"]
|
||||
if pkNamingRule.Function != "primary_key_naming" {
|
||||
t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
|
||||
}
|
||||
|
||||
if pkNamingRule.Pattern != "^id_" {
|
||||
t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
|
||||
}
|
||||
|
||||
// Test datatype rule
|
||||
pkDatatypeRule := config.Rules["primary_key_datatype"]
|
||||
if len(pkDatatypeRule.AllowedTypes) == 0 {
|
||||
t.Error("primary_key_datatype has no allowed types")
|
||||
}
|
||||
|
||||
// Test length rule
|
||||
tableLengthRule := config.Rules["table_name_length"]
|
||||
if tableLengthRule.MaxLength != 64 {
|
||||
t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
|
||||
}
|
||||
|
||||
// Test reserved keywords rule
|
||||
reservedRule := config.Rules["reserved_keywords"]
|
||||
if !reservedRule.CheckTables {
|
||||
t.Error("reserved_keywords.CheckTables should be true")
|
||||
}
|
||||
if !reservedRule.CheckColumns {
|
||||
t.Error("reserved_keywords.CheckColumns should be true")
|
||||
}
|
||||
}
|
||||
837
pkg/inspector/validators_test.go
Normal file
837
pkg/inspector/validators_test.go
Normal file
@@ -0,0 +1,837 @@
|
||||
package inspector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// Helper function to create test database
|
||||
func createTestDatabase() *models.Database {
|
||||
return &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigserial",
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
},
|
||||
"username": {
|
||||
Name: "username",
|
||||
Type: "varchar(50)",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
"rid_organization": {
|
||||
Name: "rid_organization",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_users_organization": {
|
||||
Name: "fk_users_organization",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_organization"},
|
||||
ReferencedTable: "organizations",
|
||||
ReferencedSchema: "public",
|
||||
ReferencedColumns: []string{"id"},
|
||||
},
|
||||
},
|
||||
Indexes: map[string]*models.Index{
|
||||
"idx_rid_organization": {
|
||||
Name: "idx_rid_organization",
|
||||
Columns: []string{"rid_organization"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "organizations",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigserial",
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
},
|
||||
"name": {
|
||||
Name: "name",
|
||||
Type: "varchar(100)",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern id",
|
||||
rule: Rule{
|
||||
Pattern: "^id$",
|
||||
Message: "Primary key should be 'id'",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern id_",
|
||||
rule: Rule{
|
||||
Pattern: "^id_",
|
||||
Message: "Primary key should start with 'id_'",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyDatatype(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "allowed type bigserial",
|
||||
rule: Rule{
|
||||
AllowedTypes: []string{"bigserial", "bigint", "int"},
|
||||
Message: "Primary key should use integer types",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "disallowed type",
|
||||
rule: Rule{
|
||||
AllowedTypes: []string{"uuid"},
|
||||
Message: "Primary key should use UUID",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "require auto increment",
|
||||
rule: Rule{
|
||||
RequireAutoIncrement: true,
|
||||
Message: "Primary key should have auto-increment",
|
||||
},
|
||||
wantLen: 0, // No violations - all PKs have auto-increment
|
||||
},
|
||||
{
|
||||
name: "disallow auto increment",
|
||||
rule: Rule{
|
||||
RequireAutoIncrement: false,
|
||||
Message: "Primary key should not have auto-increment",
|
||||
},
|
||||
wantLen: 2, // 2 violations - both PKs have auto-increment
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyColumnNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern rid_",
|
||||
rule: Rule{
|
||||
Pattern: "^rid_",
|
||||
Message: "Foreign key columns should start with 'rid_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern fk_",
|
||||
rule: Rule{
|
||||
Pattern: "^fk_",
|
||||
Message: "Foreign key columns should start with 'fk_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "matching pattern fk_",
|
||||
rule: Rule{
|
||||
Pattern: "^fk_",
|
||||
Message: "Foreign key constraints should start with 'fk_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching pattern FK_",
|
||||
rule: Rule{
|
||||
Pattern: "^FK_",
|
||||
Message: "Foreign key constraints should start with 'FK_'",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateForeignKeyIndex(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "require index with index present",
|
||||
rule: Rule{
|
||||
RequireIndex: true,
|
||||
Message: "Foreign key columns should have indexes",
|
||||
},
|
||||
wantLen: 1,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "no requirement",
|
||||
rule: Rule{
|
||||
RequireIndex: false,
|
||||
Message: "Foreign key index check disabled",
|
||||
},
|
||||
wantLen: 0,
|
||||
wantPass: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateForeignKeyIndex(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTableNamingCase(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "lowercase snake_case pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-z0-9_]*$",
|
||||
Case: "lowercase",
|
||||
Message: "Table names should be lowercase snake_case",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[A-Z][A-Z0-9_]*$",
|
||||
Case: "uppercase",
|
||||
Message: "Table names should be uppercase",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateTableNamingCase(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
if len(results) > 0 && results[0].Passed != tt.wantPass {
|
||||
t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnNamingCase(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "lowercase snake_case pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-z0-9_]*$",
|
||||
Case: "lowercase",
|
||||
Message: "Column names should be lowercase snake_case",
|
||||
},
|
||||
wantLen: 5, // 5 total columns across both tables
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "camelCase pattern",
|
||||
rule: Rule{
|
||||
Pattern: "^[a-z][a-zA-Z0-9]*$",
|
||||
Case: "camelCase",
|
||||
Message: "Column names should be camelCase",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: false, // rid_organization has underscore
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateColumnNamingCase(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTableNameLength(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "max length 64",
|
||||
rule: Rule{
|
||||
MaxLength: 64,
|
||||
Message: "Table name too long",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "max length 5",
|
||||
rule: Rule{
|
||||
MaxLength: 5,
|
||||
Message: "Table name too long",
|
||||
},
|
||||
wantLen: 2,
|
||||
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateTableNameLength(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnNameLength(t *testing.T) {
|
||||
db := createTestDatabase()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "max length 64",
|
||||
rule: Rule{
|
||||
MaxLength: 64,
|
||||
Message: "Column name too long",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "max length 5",
|
||||
rule: Rule{
|
||||
MaxLength: 5,
|
||||
Message: "Column name too long",
|
||||
},
|
||||
wantLen: 5,
|
||||
wantPass: false, // Some columns exceed 5 chars
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateColumnNameLength(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReservedKeywords(t *testing.T) {
|
||||
// Create a database with reserved keywords
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "user", // "user" is a reserved keyword
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
"select": { // "select" is a reserved keyword
|
||||
Name: "select",
|
||||
Type: "varchar(50)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
wantLen int
|
||||
checkPasses bool
|
||||
}{
|
||||
{
|
||||
name: "check tables only",
|
||||
rule: Rule{
|
||||
CheckTables: true,
|
||||
CheckColumns: false,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 1, // "user" table
|
||||
checkPasses: false,
|
||||
},
|
||||
{
|
||||
name: "check columns only",
|
||||
rule: Rule{
|
||||
CheckTables: false,
|
||||
CheckColumns: true,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 2, // "id", "select" columns (id passes, select fails)
|
||||
checkPasses: false,
|
||||
},
|
||||
{
|
||||
name: "check both",
|
||||
rule: Rule{
|
||||
CheckTables: true,
|
||||
CheckColumns: true,
|
||||
Message: "Reserved keyword used",
|
||||
},
|
||||
wantLen: 3, // "user" table + "id", "select" columns
|
||||
checkPasses: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := validateReservedKeywords(db, tt.rule, "test_rule")
|
||||
if len(results) != tt.wantLen {
|
||||
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMissingPrimaryKey(t *testing.T) {
|
||||
// Create database with and without primary keys
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "with_pk",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "without_pk",
|
||||
Columns: map[string]*models.Column{
|
||||
"name": {
|
||||
Name: "name",
|
||||
Type: "varchar(50)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Table missing primary key",
|
||||
}
|
||||
|
||||
results := validateMissingPrimaryKey(db, rule, "test_rule")
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
|
||||
}
|
||||
|
||||
// First result should pass (with_pk has PK)
|
||||
if results[0].Passed != true {
|
||||
t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
|
||||
}
|
||||
|
||||
// Second result should fail (without_pk missing PK)
|
||||
if results[1].Passed != false {
|
||||
t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOrphanedForeignKey(t *testing.T) {
|
||||
// Create database with orphaned FK
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "users",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_nonexistent": {
|
||||
Name: "fk_nonexistent",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"rid_organization"},
|
||||
ReferencedTable: "nonexistent_table",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Foreign key references non-existent table",
|
||||
}
|
||||
|
||||
results := validateOrphanedForeignKey(db, rule, "test_rule")
|
||||
|
||||
if len(results) != 1 {
|
||||
t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
|
||||
}
|
||||
|
||||
if results[0].Passed != false {
|
||||
t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCircularDependency(t *testing.T) {
|
||||
// Create database with circular dependency
|
||||
db := &models.Database{
|
||||
Name: "testdb",
|
||||
Schemas: []*models.Schema{
|
||||
{
|
||||
Name: "public",
|
||||
Tables: []*models.Table{
|
||||
{
|
||||
Name: "table_a",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_to_b": {
|
||||
Name: "fk_to_b",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
ReferencedTable: "table_b",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "table_b",
|
||||
Columns: map[string]*models.Column{
|
||||
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
|
||||
},
|
||||
Constraints: map[string]*models.Constraint{
|
||||
"fk_to_a": {
|
||||
Name: "fk_to_a",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
ReferencedTable: "table_a",
|
||||
ReferencedSchema: "public",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Message: "Circular dependency detected",
|
||||
}
|
||||
|
||||
results := validateCircularDependency(db, rule, "test_rule")
|
||||
|
||||
// Should detect circular dependency in both tables
|
||||
if len(results) == 0 {
|
||||
t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
if result.Passed {
|
||||
t.Error("validateCircularDependency() passed=true, want false for circular dependency")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeDataType(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"varchar(50)", "varchar"},
|
||||
{"decimal(10,2)", "decimal"},
|
||||
{"int", "int"},
|
||||
{"BIGINT", "bigint"},
|
||||
{"VARCHAR(255)", "varchar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := normalizeDataType(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
value string
|
||||
expected bool
|
||||
}{
|
||||
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
|
||||
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
|
||||
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
|
||||
{"empty slice", []string{}, "foo", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.value)
|
||||
if result != tt.expected {
|
||||
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCycle(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
graph map[string][]string
|
||||
node string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "simple cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"B"},
|
||||
"B": {"C"},
|
||||
"C": {"A"},
|
||||
},
|
||||
node: "A",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"B"},
|
||||
"B": {"C"},
|
||||
"C": {},
|
||||
},
|
||||
node: "A",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "self cycle",
|
||||
graph: map[string][]string{
|
||||
"A": {"A"},
|
||||
},
|
||||
node: "A",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
visited := make(map[string]bool)
|
||||
recStack := make(map[string]bool)
|
||||
result := hasCycle(tt.node, tt.graph, visited, recStack)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatLocation(t *testing.T) {
|
||||
tests := []struct {
|
||||
schema string
|
||||
table string
|
||||
column string
|
||||
expected string
|
||||
}{
|
||||
{"public", "users", "id", "public.users.id"},
|
||||
{"public", "users", "", "public.users"},
|
||||
{"public", "", "", "public"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
result := formatLocation(tt.schema, tt.table, tt.column)
|
||||
if result != tt.expected {
|
||||
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
|
||||
tt.schema, tt.table, tt.column, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user