test(pgsql, reflectutil): add comprehensive test coverage
All checks were successful
CI / Test (1.24) (push) Successful in -26m14s
CI / Test (1.25) (push) Successful in -26m3s
CI / Lint (push) Successful in -26m28s
CI / Build (push) Successful in -26m41s
Integration Tests / Integration Tests (push) Successful in -26m21s

* Introduce tests for PostgreSQL data types and keywords.
* Implement tests for reflect utility functions.
* Ensure consistency and correctness of type conversions and keyword mappings.
* Validate behavior for various edge cases and input types.
This commit is contained in:
2026-01-31 22:30:00 +02:00
parent f2d500f98d
commit 5d9770b430
10 changed files with 4367 additions and 0 deletions

View File

@@ -0,0 +1,238 @@
package inspector
import (
"testing"
)
func TestNewInspector(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
if inspector == nil {
t.Fatal("NewInspector() returned nil")
}
if inspector.db != db {
t.Error("NewInspector() database not set correctly")
}
if inspector.config != config {
t.Error("NewInspector() config not set correctly")
}
}
func TestInspect(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() returned error: %v", err)
}
if report == nil {
t.Fatal("Inspect() returned nil report")
}
if report.Database != db.Name {
t.Errorf("Inspect() report.Database = %q, want %q", report.Database, db.Name)
}
if report.Summary.TotalRules != len(config.Rules) {
t.Errorf("Inspect() TotalRules = %d, want %d", report.Summary.TotalRules, len(config.Rules))
}
if len(report.Violations) == 0 {
t.Error("Inspect() returned no violations, expected some results")
}
}
func TestInspectWithDisabledRules(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
// Disable all rules
for name := range config.Rules {
rule := config.Rules[name]
rule.Enabled = "off"
config.Rules[name] = rule
}
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() with disabled rules returned error: %v", err)
}
if report.Summary.RulesChecked != 0 {
t.Errorf("Inspect() RulesChecked = %d, want 0 (all disabled)", report.Summary.RulesChecked)
}
if report.Summary.RulesSkipped != len(config.Rules) {
t.Errorf("Inspect() RulesSkipped = %d, want %d", report.Summary.RulesSkipped, len(config.Rules))
}
}
func TestInspectWithEnforcedRules(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
// Enable only one rule and enforce it
for name := range config.Rules {
rule := config.Rules[name]
rule.Enabled = "off"
config.Rules[name] = rule
}
primaryKeyRule := config.Rules["primary_key_naming"]
primaryKeyRule.Enabled = "enforce"
primaryKeyRule.Pattern = "^id$"
config.Rules["primary_key_naming"] = primaryKeyRule
inspector := NewInspector(db, config)
report, err := inspector.Inspect()
if err != nil {
t.Fatalf("Inspect() returned error: %v", err)
}
if report.Summary.RulesChecked != 1 {
t.Errorf("Inspect() RulesChecked = %d, want 1", report.Summary.RulesChecked)
}
// All results should be at error level for enforced rules
for _, violation := range report.Violations {
if violation.Level != "error" {
t.Errorf("Enforced rule violation has Level = %q, want \"error\"", violation.Level)
}
}
}
func TestGenerateSummary(t *testing.T) {
db := createTestDatabase()
config := GetDefaultConfig()
inspector := NewInspector(db, config)
results := []ValidationResult{
{RuleName: "rule1", Passed: true, Level: "error"},
{RuleName: "rule2", Passed: false, Level: "error"},
{RuleName: "rule3", Passed: false, Level: "warning"},
{RuleName: "rule4", Passed: true, Level: "warning"},
}
summary := inspector.generateSummary(results)
if summary.PassedCount != 2 {
t.Errorf("generateSummary() PassedCount = %d, want 2", summary.PassedCount)
}
if summary.ErrorCount != 1 {
t.Errorf("generateSummary() ErrorCount = %d, want 1", summary.ErrorCount)
}
if summary.WarningCount != 1 {
t.Errorf("generateSummary() WarningCount = %d, want 1", summary.WarningCount)
}
}
func TestHasErrors(t *testing.T) {
tests := []struct {
name string
report *InspectorReport
want bool
}{
{
name: "with errors",
report: &InspectorReport{
Summary: ReportSummary{
ErrorCount: 5,
},
},
want: true,
},
{
name: "without errors",
report: &InspectorReport{
Summary: ReportSummary{
ErrorCount: 0,
WarningCount: 3,
},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.report.HasErrors(); got != tt.want {
t.Errorf("HasErrors() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetValidator(t *testing.T) {
tests := []struct {
name string
functionName string
wantExists bool
}{
{"primary_key_naming", "primary_key_naming", true},
{"primary_key_datatype", "primary_key_datatype", true},
{"foreign_key_column_naming", "foreign_key_column_naming", true},
{"table_regexpr", "table_regexpr", true},
{"column_regexpr", "column_regexpr", true},
{"reserved_words", "reserved_words", true},
{"have_primary_key", "have_primary_key", true},
{"orphaned_foreign_key", "orphaned_foreign_key", true},
{"circular_dependency", "circular_dependency", true},
{"unknown_function", "unknown_function", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, exists := getValidator(tt.functionName)
if exists != tt.wantExists {
t.Errorf("getValidator(%q) exists = %v, want %v", tt.functionName, exists, tt.wantExists)
}
})
}
}
func TestCreateResult(t *testing.T) {
result := createResult(
"test_rule",
true,
"Test message",
"schema.table.column",
map[string]interface{}{
"key1": "value1",
"key2": 42,
},
)
if result.RuleName != "test_rule" {
t.Errorf("createResult() RuleName = %q, want \"test_rule\"", result.RuleName)
}
if !result.Passed {
t.Error("createResult() Passed = false, want true")
}
if result.Message != "Test message" {
t.Errorf("createResult() Message = %q, want \"Test message\"", result.Message)
}
if result.Location != "schema.table.column" {
t.Errorf("createResult() Location = %q, want \"schema.table.column\"", result.Location)
}
if len(result.Context) != 2 {
t.Errorf("createResult() Context length = %d, want 2", len(result.Context))
}
}

View File

@@ -0,0 +1,366 @@
package inspector
import (
"bytes"
"encoding/json"
"strings"
"testing"
"time"
)
func createTestReport() *InspectorReport {
return &InspectorReport{
Summary: ReportSummary{
TotalRules: 10,
RulesChecked: 8,
RulesSkipped: 2,
ErrorCount: 3,
WarningCount: 5,
PassedCount: 12,
},
Violations: []ValidationResult{
{
RuleName: "primary_key_naming",
Level: "error",
Message: "Primary key should start with 'id_'",
Location: "public.users.user_id",
Passed: false,
Context: map[string]interface{}{
"schema": "public",
"table": "users",
"column": "user_id",
"pattern": "^id_",
},
},
{
RuleName: "table_name_length",
Level: "warning",
Message: "Table name too long",
Location: "public.very_long_table_name_that_exceeds_limits",
Passed: false,
Context: map[string]interface{}{
"schema": "public",
"table": "very_long_table_name_that_exceeds_limits",
"length": 44,
"max_length": 32,
},
},
},
GeneratedAt: time.Now(),
Database: "testdb",
SourceFormat: "postgresql",
}
}
func TestNewMarkdownFormatter(t *testing.T) {
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
if formatter == nil {
t.Fatal("NewMarkdownFormatter() returned nil")
}
// Buffer is not a terminal, so colors should be disabled
if formatter.UseColors {
t.Error("NewMarkdownFormatter() UseColors should be false for non-terminal")
}
}
func TestNewJSONFormatter(t *testing.T) {
formatter := NewJSONFormatter()
if formatter == nil {
t.Fatal("NewJSONFormatter() returned nil")
}
}
func TestMarkdownFormatter_Format(t *testing.T) {
report := createTestReport()
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
}
// Check that output contains expected sections
if !strings.Contains(output, "# RelSpec Inspector Report") {
t.Error("Markdown output missing header")
}
if !strings.Contains(output, "Database:") {
t.Error("Markdown output missing database field")
}
if !strings.Contains(output, "testdb") {
t.Error("Markdown output missing database name")
}
if !strings.Contains(output, "Summary") {
t.Error("Markdown output missing summary section")
}
if !strings.Contains(output, "Rules Checked: 8") {
t.Error("Markdown output missing rules checked count")
}
if !strings.Contains(output, "Errors: 3") {
t.Error("Markdown output missing error count")
}
if !strings.Contains(output, "Warnings: 5") {
t.Error("Markdown output missing warning count")
}
if !strings.Contains(output, "Violations") {
t.Error("Markdown output missing violations section")
}
if !strings.Contains(output, "primary_key_naming") {
t.Error("Markdown output missing rule name")
}
if !strings.Contains(output, "public.users.user_id") {
t.Error("Markdown output missing location")
}
}
func TestMarkdownFormatter_FormatNoViolations(t *testing.T) {
report := &InspectorReport{
Summary: ReportSummary{
TotalRules: 10,
RulesChecked: 10,
RulesSkipped: 0,
ErrorCount: 0,
WarningCount: 0,
PassedCount: 50,
},
Violations: []ValidationResult{},
GeneratedAt: time.Now(),
Database: "testdb",
SourceFormat: "postgresql",
}
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("MarkdownFormatter.Format() returned error: %v", err)
}
if !strings.Contains(output, "No violations found") {
t.Error("Markdown output should indicate no violations")
}
}
func TestJSONFormatter_Format(t *testing.T) {
report := createTestReport()
formatter := NewJSONFormatter()
output, err := formatter.Format(report)
if err != nil {
t.Fatalf("JSONFormatter.Format() returned error: %v", err)
}
// Verify it's valid JSON
var decoded InspectorReport
if err := json.Unmarshal([]byte(output), &decoded); err != nil {
t.Fatalf("JSONFormatter.Format() produced invalid JSON: %v", err)
}
// Check key fields
if decoded.Database != "testdb" {
t.Errorf("JSON decoded Database = %q, want \"testdb\"", decoded.Database)
}
if decoded.Summary.ErrorCount != 3 {
t.Errorf("JSON decoded ErrorCount = %d, want 3", decoded.Summary.ErrorCount)
}
if len(decoded.Violations) != 2 {
t.Errorf("JSON decoded Violations length = %d, want 2", len(decoded.Violations))
}
}
func TestMarkdownFormatter_FormatHeader(t *testing.T) {
var buf bytes.Buffer
formatter := NewMarkdownFormatter(&buf)
header := formatter.formatHeader("Test Header")
if !strings.Contains(header, "# Test Header") {
t.Errorf("formatHeader() = %q, want to contain \"# Test Header\"", header)
}
}
func TestMarkdownFormatter_FormatBold(t *testing.T) {
tests := []struct {
name string
useColors bool
text string
wantContains string
}{
{
name: "without colors",
useColors: false,
text: "Bold Text",
wantContains: "**Bold Text**",
},
{
name: "with colors",
useColors: true,
text: "Bold Text",
wantContains: "Bold Text",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: tt.useColors}
result := formatter.formatBold(tt.text)
if !strings.Contains(result, tt.wantContains) {
t.Errorf("formatBold() = %q, want to contain %q", result, tt.wantContains)
}
})
}
}
func TestMarkdownFormatter_Colorize(t *testing.T) {
tests := []struct {
name string
useColors bool
text string
color string
wantColor bool
}{
{
name: "without colors",
useColors: false,
text: "Test",
color: colorRed,
wantColor: false,
},
{
name: "with colors",
useColors: true,
text: "Test",
color: colorRed,
wantColor: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: tt.useColors}
result := formatter.colorize(tt.text, tt.color)
hasColor := strings.Contains(result, tt.color)
if hasColor != tt.wantColor {
t.Errorf("colorize() has color codes = %v, want %v", hasColor, tt.wantColor)
}
if !strings.Contains(result, tt.text) {
t.Errorf("colorize() doesn't contain original text %q", tt.text)
}
})
}
}
func TestMarkdownFormatter_FormatContext(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: false}
context := map[string]interface{}{
"schema": "public",
"table": "users",
"column": "id",
"pattern": "^id_",
"max_length": 64,
}
result := formatter.formatContext(context)
// Should not include schema, table, column (they're in location)
if strings.Contains(result, "schema") {
t.Error("formatContext() should skip schema field")
}
if strings.Contains(result, "table=") {
t.Error("formatContext() should skip table field")
}
if strings.Contains(result, "column=") {
t.Error("formatContext() should skip column field")
}
// Should include other fields
if !strings.Contains(result, "pattern") {
t.Error("formatContext() should include pattern field")
}
if !strings.Contains(result, "max_length") {
t.Error("formatContext() should include max_length field")
}
}
func TestMarkdownFormatter_FormatViolation(t *testing.T) {
formatter := &MarkdownFormatter{UseColors: false}
violation := ValidationResult{
RuleName: "test_rule",
Level: "error",
Message: "Test violation message",
Location: "public.users.id",
Passed: false,
Context: map[string]interface{}{
"pattern": "^id_",
},
}
result := formatter.formatViolation(violation, colorRed)
if !strings.Contains(result, "test_rule") {
t.Error("formatViolation() should include rule name")
}
if !strings.Contains(result, "Test violation message") {
t.Error("formatViolation() should include message")
}
if !strings.Contains(result, "public.users.id") {
t.Error("formatViolation() should include location")
}
if !strings.Contains(result, "Location:") {
t.Error("formatViolation() should include Location label")
}
if !strings.Contains(result, "Message:") {
t.Error("formatViolation() should include Message label")
}
}
func TestReportFormatConstants(t *testing.T) {
// Test that color constants are defined
if colorReset == "" {
t.Error("colorReset is not defined")
}
if colorRed == "" {
t.Error("colorRed is not defined")
}
if colorYellow == "" {
t.Error("colorYellow is not defined")
}
if colorGreen == "" {
t.Error("colorGreen is not defined")
}
if colorBold == "" {
t.Error("colorBold is not defined")
}
}

249
pkg/inspector/rules_test.go Normal file
View File

@@ -0,0 +1,249 @@
package inspector
import (
"os"
"path/filepath"
"testing"
)
func TestGetDefaultConfig(t *testing.T) {
config := GetDefaultConfig()
if config == nil {
t.Fatal("GetDefaultConfig() returned nil")
}
if config.Version != "1.0" {
t.Errorf("GetDefaultConfig() Version = %q, want \"1.0\"", config.Version)
}
if len(config.Rules) == 0 {
t.Error("GetDefaultConfig() returned no rules")
}
// Check that all expected rules are present
expectedRules := []string{
"primary_key_naming",
"primary_key_datatype",
"primary_key_auto_increment",
"foreign_key_column_naming",
"foreign_key_constraint_naming",
"foreign_key_index",
"table_naming_case",
"column_naming_case",
"table_name_length",
"column_name_length",
"reserved_keywords",
"missing_primary_key",
"orphaned_foreign_key",
"circular_dependency",
}
for _, ruleName := range expectedRules {
if _, exists := config.Rules[ruleName]; !exists {
t.Errorf("GetDefaultConfig() missing rule: %q", ruleName)
}
}
}
func TestLoadConfig_NonExistentFile(t *testing.T) {
// Try to load a non-existent file
config, err := LoadConfig("/path/to/nonexistent/file.yaml")
if err != nil {
t.Fatalf("LoadConfig() with non-existent file returned error: %v", err)
}
// Should return default config
if config == nil {
t.Fatal("LoadConfig() returned nil config for non-existent file")
}
if len(config.Rules) == 0 {
t.Error("LoadConfig() returned config with no rules")
}
}
func TestLoadConfig_ValidFile(t *testing.T) {
// Create a temporary config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test-config.yaml")
configContent := `version: "1.0"
rules:
primary_key_naming:
enabled: "enforce"
function: "primary_key_naming"
pattern: "^pk_"
message: "Primary keys must start with pk_"
table_name_length:
enabled: "warn"
function: "table_name_length"
max_length: 50
message: "Table name too long"
`
err := os.WriteFile(configPath, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
config, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() returned error: %v", err)
}
if config.Version != "1.0" {
t.Errorf("LoadConfig() Version = %q, want \"1.0\"", config.Version)
}
if len(config.Rules) != 2 {
t.Errorf("LoadConfig() loaded %d rules, want 2", len(config.Rules))
}
// Check primary_key_naming rule
pkRule, exists := config.Rules["primary_key_naming"]
if !exists {
t.Fatal("LoadConfig() missing primary_key_naming rule")
}
if pkRule.Enabled != "enforce" {
t.Errorf("primary_key_naming.Enabled = %q, want \"enforce\"", pkRule.Enabled)
}
if pkRule.Pattern != "^pk_" {
t.Errorf("primary_key_naming.Pattern = %q, want \"^pk_\"", pkRule.Pattern)
}
// Check table_name_length rule
lengthRule, exists := config.Rules["table_name_length"]
if !exists {
t.Fatal("LoadConfig() missing table_name_length rule")
}
if lengthRule.MaxLength != 50 {
t.Errorf("table_name_length.MaxLength = %d, want 50", lengthRule.MaxLength)
}
}
func TestLoadConfig_InvalidYAML(t *testing.T) {
// Create a temporary invalid config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "invalid-config.yaml")
invalidContent := `invalid: yaml: content: {[}]`
err := os.WriteFile(configPath, []byte(invalidContent), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
_, err = LoadConfig(configPath)
if err == nil {
t.Error("LoadConfig() with invalid YAML did not return error")
}
}
func TestRuleIsEnabled(t *testing.T) {
tests := []struct {
name string
rule Rule
want bool
}{
{
name: "enforce is enabled",
rule: Rule{Enabled: "enforce"},
want: true,
},
{
name: "warn is enabled",
rule: Rule{Enabled: "warn"},
want: true,
},
{
name: "off is not enabled",
rule: Rule{Enabled: "off"},
want: false,
},
{
name: "empty is not enabled",
rule: Rule{Enabled: ""},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.rule.IsEnabled(); got != tt.want {
t.Errorf("Rule.IsEnabled() = %v, want %v", got, tt.want)
}
})
}
}
func TestRuleIsEnforced(t *testing.T) {
tests := []struct {
name string
rule Rule
want bool
}{
{
name: "enforce is enforced",
rule: Rule{Enabled: "enforce"},
want: true,
},
{
name: "warn is not enforced",
rule: Rule{Enabled: "warn"},
want: false,
},
{
name: "off is not enforced",
rule: Rule{Enabled: "off"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.rule.IsEnforced(); got != tt.want {
t.Errorf("Rule.IsEnforced() = %v, want %v", got, tt.want)
}
})
}
}
func TestDefaultConfigRuleSettings(t *testing.T) {
config := GetDefaultConfig()
// Test specific rule settings
pkNamingRule := config.Rules["primary_key_naming"]
if pkNamingRule.Function != "primary_key_naming" {
t.Errorf("primary_key_naming.Function = %q, want \"primary_key_naming\"", pkNamingRule.Function)
}
if pkNamingRule.Pattern != "^id_" {
t.Errorf("primary_key_naming.Pattern = %q, want \"^id_\"", pkNamingRule.Pattern)
}
// Test datatype rule
pkDatatypeRule := config.Rules["primary_key_datatype"]
if len(pkDatatypeRule.AllowedTypes) == 0 {
t.Error("primary_key_datatype has no allowed types")
}
// Test length rule
tableLengthRule := config.Rules["table_name_length"]
if tableLengthRule.MaxLength != 64 {
t.Errorf("table_name_length.MaxLength = %d, want 64", tableLengthRule.MaxLength)
}
// Test reserved keywords rule
reservedRule := config.Rules["reserved_keywords"]
if !reservedRule.CheckTables {
t.Error("reserved_keywords.CheckTables should be true")
}
if !reservedRule.CheckColumns {
t.Error("reserved_keywords.CheckColumns should be true")
}
}

View File

@@ -0,0 +1,837 @@
package inspector
import (
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// Helper function to create test database
func createTestDatabase() *models.Database {
return &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigserial",
IsPrimaryKey: true,
AutoIncrement: true,
},
"username": {
Name: "username",
Type: "varchar(50)",
NotNull: true,
IsPrimaryKey: false,
},
"rid_organization": {
Name: "rid_organization",
Type: "bigint",
NotNull: true,
IsPrimaryKey: false,
},
},
Constraints: map[string]*models.Constraint{
"fk_users_organization": {
Name: "fk_users_organization",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_organization"},
ReferencedTable: "organizations",
ReferencedSchema: "public",
ReferencedColumns: []string{"id"},
},
},
Indexes: map[string]*models.Index{
"idx_rid_organization": {
Name: "idx_rid_organization",
Columns: []string{"rid_organization"},
},
},
},
{
Name: "organizations",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigserial",
IsPrimaryKey: true,
AutoIncrement: true,
},
"name": {
Name: "name",
Type: "varchar(100)",
NotNull: true,
IsPrimaryKey: false,
},
},
},
},
},
},
}
}
func TestValidatePrimaryKeyNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern id",
rule: Rule{
Pattern: "^id$",
Message: "Primary key should be 'id'",
},
wantLen: 2,
wantPass: true,
},
{
name: "non-matching pattern id_",
rule: Rule{
Pattern: "^id_",
Message: "Primary key should start with 'id_'",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validatePrimaryKeyNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidatePrimaryKeyDatatype(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "allowed type bigserial",
rule: Rule{
AllowedTypes: []string{"bigserial", "bigint", "int"},
Message: "Primary key should use integer types",
},
wantLen: 2,
wantPass: true,
},
{
name: "disallowed type",
rule: Rule{
AllowedTypes: []string{"uuid"},
Message: "Primary key should use UUID",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyDatatype(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyDatatype() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validatePrimaryKeyDatatype() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidatePrimaryKeyAutoIncrement(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
}{
{
name: "require auto increment",
rule: Rule{
RequireAutoIncrement: true,
Message: "Primary key should have auto-increment",
},
wantLen: 0, // No violations - all PKs have auto-increment
},
{
name: "disallow auto increment",
rule: Rule{
RequireAutoIncrement: false,
Message: "Primary key should not have auto-increment",
},
wantLen: 2, // 2 violations - both PKs have auto-increment
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validatePrimaryKeyAutoIncrement(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validatePrimaryKeyAutoIncrement() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateForeignKeyColumnNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern rid_",
rule: Rule{
Pattern: "^rid_",
Message: "Foreign key columns should start with 'rid_'",
},
wantLen: 1,
wantPass: true,
},
{
name: "non-matching pattern fk_",
rule: Rule{
Pattern: "^fk_",
Message: "Foreign key columns should start with 'fk_'",
},
wantLen: 1,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyColumnNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyColumnNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyColumnNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateForeignKeyConstraintNaming(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "matching pattern fk_",
rule: Rule{
Pattern: "^fk_",
Message: "Foreign key constraints should start with 'fk_'",
},
wantLen: 1,
wantPass: true,
},
{
name: "non-matching pattern FK_",
rule: Rule{
Pattern: "^FK_",
Message: "Foreign key constraints should start with 'FK_'",
},
wantLen: 1,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyConstraintNaming(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyConstraintNaming() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyConstraintNaming() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateForeignKeyIndex(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "require index with index present",
rule: Rule{
RequireIndex: true,
Message: "Foreign key columns should have indexes",
},
wantLen: 1,
wantPass: true,
},
{
name: "no requirement",
rule: Rule{
RequireIndex: false,
Message: "Foreign key index check disabled",
},
wantLen: 0,
wantPass: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateForeignKeyIndex(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateForeignKeyIndex() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateForeignKeyIndex() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateTableNamingCase(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "lowercase snake_case pattern",
rule: Rule{
Pattern: "^[a-z][a-z0-9_]*$",
Case: "lowercase",
Message: "Table names should be lowercase snake_case",
},
wantLen: 2,
wantPass: true,
},
{
name: "uppercase pattern",
rule: Rule{
Pattern: "^[A-Z][A-Z0-9_]*$",
Case: "uppercase",
Message: "Table names should be uppercase",
},
wantLen: 2,
wantPass: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateTableNamingCase(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateTableNamingCase() returned %d results, want %d", len(results), tt.wantLen)
}
if len(results) > 0 && results[0].Passed != tt.wantPass {
t.Errorf("validateTableNamingCase() passed=%v, want %v", results[0].Passed, tt.wantPass)
}
})
}
}
func TestValidateColumnNamingCase(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "lowercase snake_case pattern",
rule: Rule{
Pattern: "^[a-z][a-z0-9_]*$",
Case: "lowercase",
Message: "Column names should be lowercase snake_case",
},
wantLen: 5, // 5 total columns across both tables
wantPass: true,
},
{
name: "camelCase pattern",
rule: Rule{
Pattern: "^[a-z][a-zA-Z0-9]*$",
Case: "camelCase",
Message: "Column names should be camelCase",
},
wantLen: 5,
wantPass: false, // rid_organization has underscore
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateColumnNamingCase(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateColumnNamingCase() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateTableNameLength(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "max length 64",
rule: Rule{
MaxLength: 64,
Message: "Table name too long",
},
wantLen: 2,
wantPass: true,
},
{
name: "max length 5",
rule: Rule{
MaxLength: 5,
Message: "Table name too long",
},
wantLen: 2,
wantPass: false, // "users" is 5 chars (passes), "organizations" is 13 (fails)
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateTableNameLength(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateTableNameLength() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateColumnNameLength(t *testing.T) {
db := createTestDatabase()
tests := []struct {
name string
rule Rule
wantLen int
wantPass bool
}{
{
name: "max length 64",
rule: Rule{
MaxLength: 64,
Message: "Column name too long",
},
wantLen: 5,
wantPass: true,
},
{
name: "max length 5",
rule: Rule{
MaxLength: 5,
Message: "Column name too long",
},
wantLen: 5,
wantPass: false, // Some columns exceed 5 chars
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateColumnNameLength(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateColumnNameLength() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateReservedKeywords(t *testing.T) {
// Create a database with reserved keywords
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "user", // "user" is a reserved keyword
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
"select": { // "select" is a reserved keyword
Name: "select",
Type: "varchar(50)",
},
},
},
},
},
},
}
tests := []struct {
name string
rule Rule
wantLen int
checkPasses bool
}{
{
name: "check tables only",
rule: Rule{
CheckTables: true,
CheckColumns: false,
Message: "Reserved keyword used",
},
wantLen: 1, // "user" table
checkPasses: false,
},
{
name: "check columns only",
rule: Rule{
CheckTables: false,
CheckColumns: true,
Message: "Reserved keyword used",
},
wantLen: 2, // "id", "select" columns (id passes, select fails)
checkPasses: false,
},
{
name: "check both",
rule: Rule{
CheckTables: true,
CheckColumns: true,
Message: "Reserved keyword used",
},
wantLen: 3, // "user" table + "id", "select" columns
checkPasses: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results := validateReservedKeywords(db, tt.rule, "test_rule")
if len(results) != tt.wantLen {
t.Errorf("validateReservedKeywords() returned %d results, want %d", len(results), tt.wantLen)
}
})
}
}
func TestValidateMissingPrimaryKey(t *testing.T) {
// Create database with and without primary keys
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "with_pk",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
},
},
{
Name: "without_pk",
Columns: map[string]*models.Column{
"name": {
Name: "name",
Type: "varchar(50)",
},
},
},
},
},
},
}
rule := Rule{
Message: "Table missing primary key",
}
results := validateMissingPrimaryKey(db, rule, "test_rule")
if len(results) != 2 {
t.Errorf("validateMissingPrimaryKey() returned %d results, want 2", len(results))
}
// First result should pass (with_pk has PK)
if results[0].Passed != true {
t.Errorf("validateMissingPrimaryKey() result[0].Passed=%v, want true", results[0].Passed)
}
// Second result should fail (without_pk missing PK)
if results[1].Passed != false {
t.Errorf("validateMissingPrimaryKey() result[1].Passed=%v, want false", results[1].Passed)
}
}
func TestValidateOrphanedForeignKey(t *testing.T) {
// Create database with orphaned FK
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "users",
Columns: map[string]*models.Column{
"id": {
Name: "id",
Type: "bigint",
IsPrimaryKey: true,
},
},
Constraints: map[string]*models.Constraint{
"fk_nonexistent": {
Name: "fk_nonexistent",
Type: models.ForeignKeyConstraint,
Columns: []string{"rid_organization"},
ReferencedTable: "nonexistent_table",
ReferencedSchema: "public",
},
},
},
},
},
},
}
rule := Rule{
Message: "Foreign key references non-existent table",
}
results := validateOrphanedForeignKey(db, rule, "test_rule")
if len(results) != 1 {
t.Errorf("validateOrphanedForeignKey() returned %d results, want 1", len(results))
}
if results[0].Passed != false {
t.Errorf("validateOrphanedForeignKey() passed=%v, want false", results[0].Passed)
}
}
func TestValidateCircularDependency(t *testing.T) {
// Create database with circular dependency
db := &models.Database{
Name: "testdb",
Schemas: []*models.Schema{
{
Name: "public",
Tables: []*models.Table{
{
Name: "table_a",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
},
Constraints: map[string]*models.Constraint{
"fk_to_b": {
Name: "fk_to_b",
Type: models.ForeignKeyConstraint,
ReferencedTable: "table_b",
ReferencedSchema: "public",
},
},
},
{
Name: "table_b",
Columns: map[string]*models.Column{
"id": {Name: "id", Type: "bigint", IsPrimaryKey: true},
},
Constraints: map[string]*models.Constraint{
"fk_to_a": {
Name: "fk_to_a",
Type: models.ForeignKeyConstraint,
ReferencedTable: "table_a",
ReferencedSchema: "public",
},
},
},
},
},
},
}
rule := Rule{
Message: "Circular dependency detected",
}
results := validateCircularDependency(db, rule, "test_rule")
// Should detect circular dependency in both tables
if len(results) == 0 {
t.Error("validateCircularDependency() returned 0 results, expected circular dependency detection")
}
for _, result := range results {
if result.Passed {
t.Error("validateCircularDependency() passed=true, want false for circular dependency")
}
}
}
func TestNormalizeDataType(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"varchar(50)", "varchar"},
{"decimal(10,2)", "decimal"},
{"int", "int"},
{"BIGINT", "bigint"},
{"VARCHAR(255)", "varchar"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := normalizeDataType(tt.input)
if result != tt.expected {
t.Errorf("normalizeDataType(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestContains(t *testing.T) {
tests := []struct {
name string
slice []string
value string
expected bool
}{
{"found exact", []string{"foo", "bar", "baz"}, "bar", true},
{"not found", []string{"foo", "bar", "baz"}, "qux", false},
{"case insensitive match", []string{"foo", "Bar", "baz"}, "bar", true},
{"empty slice", []string{}, "foo", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.slice, tt.value)
if result != tt.expected {
t.Errorf("contains(%v, %q) = %v, want %v", tt.slice, tt.value, result, tt.expected)
}
})
}
}
func TestHasCycle(t *testing.T) {
tests := []struct {
name string
graph map[string][]string
node string
expected bool
}{
{
name: "simple cycle",
graph: map[string][]string{
"A": {"B"},
"B": {"C"},
"C": {"A"},
},
node: "A",
expected: true,
},
{
name: "no cycle",
graph: map[string][]string{
"A": {"B"},
"B": {"C"},
"C": {},
},
node: "A",
expected: false,
},
{
name: "self cycle",
graph: map[string][]string{
"A": {"A"},
},
node: "A",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
visited := make(map[string]bool)
recStack := make(map[string]bool)
result := hasCycle(tt.node, tt.graph, visited, recStack)
if result != tt.expected {
t.Errorf("hasCycle() = %v, want %v", result, tt.expected)
}
})
}
}
func TestFormatLocation(t *testing.T) {
tests := []struct {
schema string
table string
column string
expected string
}{
{"public", "users", "id", "public.users.id"},
{"public", "users", "", "public.users"},
{"public", "", "", "public"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := formatLocation(tt.schema, tt.table, tt.column)
if result != tt.expected {
t.Errorf("formatLocation(%q, %q, %q) = %q, want %q",
tt.schema, tt.table, tt.column, result, tt.expected)
}
})
}
}