mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
255 lines
6.1 KiB
Go
255 lines
6.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
func TestIPBlacklist_BlockIP(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
|
|
// Block an IP
|
|
err := bl.BlockIP("192.168.1.100", "Suspicious activity")
|
|
if err != nil {
|
|
t.Fatalf("BlockIP() error = %v", err)
|
|
}
|
|
|
|
// Check if IP is blocked
|
|
blocked, reason := bl.IsBlocked("192.168.1.100")
|
|
if !blocked {
|
|
t.Error("IP should be blocked")
|
|
}
|
|
if reason != "Suspicious activity" {
|
|
t.Errorf("Reason = %q, want %q", reason, "Suspicious activity")
|
|
}
|
|
|
|
// Check non-blocked IP
|
|
blocked, _ = bl.IsBlocked("192.168.1.1")
|
|
if blocked {
|
|
t.Error("IP should not be blocked")
|
|
}
|
|
}
|
|
|
|
func TestIPBlacklist_BlockCIDR(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
|
|
// Block a CIDR range
|
|
err := bl.BlockCIDR("10.0.0.0/24", "Internal network blocked")
|
|
if err != nil {
|
|
t.Fatalf("BlockCIDR() error = %v", err)
|
|
}
|
|
|
|
// Check IPs in range
|
|
testIPs := []string{
|
|
"10.0.0.1",
|
|
"10.0.0.100",
|
|
"10.0.0.254",
|
|
}
|
|
|
|
for _, ip := range testIPs {
|
|
blocked, _ := bl.IsBlocked(ip)
|
|
if !blocked {
|
|
t.Errorf("IP %s should be blocked by CIDR", ip)
|
|
}
|
|
}
|
|
|
|
// Check IP outside range
|
|
blocked, _ := bl.IsBlocked("10.0.1.1")
|
|
if blocked {
|
|
t.Error("IP outside CIDR range should not be blocked")
|
|
}
|
|
}
|
|
|
|
func TestIPBlacklist_UnblockIP(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
|
|
// Block and then unblock
|
|
bl.BlockIP("192.168.1.100", "Test")
|
|
|
|
blocked, _ := bl.IsBlocked("192.168.1.100")
|
|
if !blocked {
|
|
t.Error("IP should be blocked")
|
|
}
|
|
|
|
bl.UnblockIP("192.168.1.100")
|
|
|
|
blocked, _ = bl.IsBlocked("192.168.1.100")
|
|
if blocked {
|
|
t.Error("IP should be unblocked")
|
|
}
|
|
}
|
|
|
|
func TestIPBlacklist_UnblockCIDR(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
|
|
// Block and then unblock CIDR
|
|
bl.BlockCIDR("10.0.0.0/24", "Test")
|
|
|
|
blocked, _ := bl.IsBlocked("10.0.0.1")
|
|
if !blocked {
|
|
t.Error("IP should be blocked by CIDR")
|
|
}
|
|
|
|
bl.UnblockCIDR("10.0.0.0/24")
|
|
|
|
blocked, _ = bl.IsBlocked("10.0.0.1")
|
|
if blocked {
|
|
t.Error("IP should be unblocked after CIDR removal")
|
|
}
|
|
}
|
|
|
|
func TestIPBlacklist_Middleware(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
bl.BlockIP("192.168.1.100", "Banned")
|
|
|
|
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
}))
|
|
|
|
// Blocked IP should get 403
|
|
t.Run("BlockedIP", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.100:12345"
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusForbidden {
|
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
|
}
|
|
|
|
var response map[string]interface{}
|
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
|
t.Fatalf("Failed to parse response: %v", err)
|
|
}
|
|
|
|
if response["error"] != "forbidden" {
|
|
t.Errorf("Error = %v, want %q", response["error"], "forbidden")
|
|
}
|
|
})
|
|
|
|
// Allowed IP should succeed
|
|
t.Run("AllowedIP", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestIPBlacklist_MiddlewareWithProxy(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: true})
|
|
bl.BlockIP("203.0.113.1", "Blocked via proxy")
|
|
|
|
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Test X-Forwarded-For
|
|
t.Run("X-Forwarded-For", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.1:12345"
|
|
req.Header.Set("X-Forwarded-For", "203.0.113.1")
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusForbidden {
|
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
|
}
|
|
})
|
|
|
|
// Test X-Real-IP
|
|
t.Run("X-Real-IP", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "10.0.0.1:12345"
|
|
req.Header.Set("X-Real-IP", "203.0.113.1")
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusForbidden {
|
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestIPBlacklist_StatsHandler(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
bl.BlockIP("192.168.1.100", "Test1")
|
|
bl.BlockIP("192.168.1.101", "Test2")
|
|
bl.BlockCIDR("10.0.0.0/24", "Test CIDR")
|
|
|
|
handler := bl.StatsHandler()
|
|
|
|
req := httptest.NewRequest("GET", "/blacklist-stats", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
|
}
|
|
|
|
var stats map[string]interface{}
|
|
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
|
t.Fatalf("Failed to parse response: %v", err)
|
|
}
|
|
|
|
if int(stats["total_ips"].(float64)) != 2 {
|
|
t.Errorf("total_ips = %v, want 2", stats["total_ips"])
|
|
}
|
|
|
|
if int(stats["total_cidrs"].(float64)) != 1 {
|
|
t.Errorf("total_cidrs = %v, want 1", stats["total_cidrs"])
|
|
}
|
|
}
|
|
|
|
func TestIPBlacklist_GetBlacklist(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
bl.BlockIP("192.168.1.100", "")
|
|
bl.BlockIP("192.168.1.101", "")
|
|
bl.BlockCIDR("10.0.0.0/24", "")
|
|
|
|
ips, cidrs := bl.GetBlacklist()
|
|
|
|
if len(ips) != 2 {
|
|
t.Errorf("len(ips) = %d, want 2", len(ips))
|
|
}
|
|
|
|
if len(cidrs) != 1 {
|
|
t.Errorf("len(cidrs) = %d, want 1", len(cidrs))
|
|
}
|
|
|
|
// Verify CIDR format
|
|
if cidrs[0] != "10.0.0.0/24" {
|
|
t.Errorf("CIDR = %q, want %q", cidrs[0], "10.0.0.0/24")
|
|
}
|
|
}
|
|
|
|
func TestIPBlacklist_InvalidIP(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
|
|
err := bl.BlockIP("invalid-ip", "Test")
|
|
if err == nil {
|
|
t.Error("BlockIP() should return error for invalid IP")
|
|
}
|
|
}
|
|
|
|
func TestIPBlacklist_InvalidCIDR(t *testing.T) {
|
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
|
|
|
err := bl.BlockCIDR("invalid-cidr", "Test")
|
|
if err == nil {
|
|
t.Error("BlockCIDR() should return error for invalid CIDR")
|
|
}
|
|
}
|