package middleware import ( "net/http" "net/http/httptest" "testing" ) func TestSanitizeXSS(t *testing.T) { sanitizer := DefaultSanitizer() tests := []struct { name string input string contains string // String that should NOT be in output }{ { name: "Script tag", input: "", contains: "John", "email": "test@example.com", "nested": map[string]interface{}{ "bio": "", }, } result := sanitizer.SanitizeMap(input) // Check that script tag was removed/escaped name, ok := result["name"].(string) if !ok || name == input["name"] { t.Error("Name should be sanitized") } // Check nested map nested, ok := result["nested"].(map[string]interface{}) if !ok { t.Fatal("Nested should still be a map") } bio, ok := nested["bio"].(string) if !ok || bio == input["nested"].(map[string]interface{})["bio"] { t.Error("Nested bio should be sanitized") } } func TestSanitizeMiddleware(t *testing.T) { sanitizer := DefaultSanitizer() handler := sanitizer.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check that query param was sanitized param := r.URL.Query().Get("q") if param == "" { t.Error("Query param should be sanitized") } w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest("GET", "/test?q=", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Handler failed: got %d, want %d", w.Code, http.StatusOK) } } func TestSanitizeFilename(t *testing.T) { tests := []struct { name string input string contains string // String that should NOT be in output }{ { name: "Path traversal", input: "../../../etc/passwd", contains: "..", }, { name: "Absolute path", input: "/etc/passwd", contains: "/", }, { name: "Windows path", input: "..\\..\\windows\\system32", contains: "\\", }, { name: "Null byte", input: "file\x00.txt", contains: "\x00", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := SanitizeFilename(tt.input) if result == tt.input { t.Errorf("SanitizeFilename() did not modify input: %q", tt.input) } }) } } func TestSanitizeEmail(t *testing.T) { tests := []struct { name string input string expected string }{ { name: "Uppercase", input: "TEST@EXAMPLE.COM", expected: "test@example.com", }, { name: "Whitespace", input: " test@example.com ", expected: "test@example.com", }, { name: "Null bytes", input: "test\x00@example.com", expected: "test@example.com", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := SanitizeEmail(tt.input) if result != tt.expected { t.Errorf("SanitizeEmail() = %q, want %q", result, tt.expected) } }) } } func TestSanitizeURL(t *testing.T) { tests := []struct { name string input string expected string }{ { name: "JavaScript protocol", input: "javascript:alert(1)", expected: "", }, { name: "Data protocol", input: "data:text/html,", expected: "", }, { name: "Valid HTTP URL", input: "https://example.com", expected: "https://example.com", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := SanitizeURL(tt.input) if result != tt.expected { t.Errorf("SanitizeURL() = %q, want %q", result, tt.expected) } }) } } func TestStrictSanitizer(t *testing.T) { sanitizer := StrictSanitizer() input := "Bold text with " result := sanitizer.Sanitize(input) // Should strip ALL HTML tags if result == input { t.Error("Strict sanitizer should modify input") } // Should not contain any HTML tags if len(result) > 0 && (result[0] == '<' || result[len(result)-1] == '>') { t.Error("Result should not contain HTML tags") } } func TestMaxStringLength(t *testing.T) { sanitizer := &Sanitizer{ MaxStringLength: 10, } input := "This is a very long string that exceeds the maximum length" result := sanitizer.Sanitize(input) if len(result) != 10 { t.Errorf("Result length = %d, want 10", len(result)) } if result != input[:10] { t.Errorf("Result = %q, want %q", result, input[:10]) } }