diff --git a/SECURITY_FEATURES.md b/SECURITY_FEATURES.md new file mode 100644 index 0000000..80ee674 --- /dev/null +++ b/SECURITY_FEATURES.md @@ -0,0 +1,440 @@ +# Security Features: Blacklist & Rate Limit Inspection + +## IP Blacklist + +The IP blacklist middleware allows you to block specific IP addresses or CIDR ranges from accessing your application. + +### Basic Usage + +```go +import "github.com/bitechdev/ResolveSpec/pkg/middleware" + +// Create blacklist (UseProxy=true if behind a proxy) +blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{ + UseProxy: true, // Checks X-Forwarded-For and X-Real-IP headers +}) + +// Block individual IP +blacklist.BlockIP("192.168.1.100", "Suspicious activity detected") + +// Block entire CIDR range +blacklist.BlockCIDR("10.0.0.0/8", "Private network blocked") + +// Apply middleware +http.Handle("/api/", blacklist.Middleware(yourHandler)) +``` + +### Managing Blacklist + +```go +// Unblock an IP +blacklist.UnblockIP("192.168.1.100") + +// Unblock a CIDR range +blacklist.UnblockCIDR("10.0.0.0/8") + +// Get all blacklisted IPs and CIDRs +ips, cidrs := blacklist.GetBlacklist() +fmt.Printf("Blocked IPs: %v\n", ips) +fmt.Printf("Blocked CIDRs: %v\n", cidrs) + +// Check if specific IP is blocked +blocked, reason := blacklist.IsBlocked("192.168.1.100") +if blocked { + fmt.Printf("IP blocked: %s\n", reason) +} +``` + +### Blacklist Statistics Endpoint + +Expose blacklist statistics via HTTP: + +```go +// Add stats endpoint +http.Handle("/admin/blacklist-stats", blacklist.StatsHandler()) +``` + +**Example Response:** +```json +{ + "blocked_ips": ["192.168.1.100", "192.168.1.101"], + "blocked_cidrs": ["10.0.0.0/8"], + "total_ips": 2, + "total_cidrs": 1 +} +``` + +### Integration Example + +```go +func main() { + // Create blacklist + blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{ + UseProxy: true, + }) + + // Block known malicious IPs + blacklist.BlockIP("203.0.113.1", "Known scanner") + blacklist.BlockCIDR("198.51.100.0/24", "Spam network") + + // Create your router + mux := http.NewServeMux() + + // Protected routes + mux.Handle("/api/", blacklist.Middleware(apiHandler)) + + // Admin endpoint to manage blacklist + mux.HandleFunc("/admin/block-ip", func(w http.ResponseWriter, r *http.Request) { + ip := r.URL.Query().Get("ip") + reason := r.URL.Query().Get("reason") + + if err := blacklist.BlockIP(ip, reason); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "Blocked %s: %s", ip, reason) + }) + + // Stats endpoint + mux.Handle("/admin/blacklist-stats", blacklist.StatsHandler()) + + http.ListenAndServe(":8080", mux) +} +``` + +--- + +## Rate Limit Inspection + +Monitor and inspect rate limit status per IP address in real-time. + +### Basic Usage + +```go +import "github.com/bitechdev/ResolveSpec/pkg/middleware" + +// Create rate limiter (10 req/sec, burst of 20) +rateLimiter := middleware.NewRateLimiter(10, 20) + +// Apply middleware +http.Handle("/api/", rateLimiter.Middleware(yourHandler)) +``` + +### Programmatic Inspection + +```go +// Get all tracked IPs +trackedIPs := rateLimiter.GetTrackedIPs() +fmt.Printf("Currently tracking %d IPs\n", len(trackedIPs)) + +// Get rate limit info for specific IP +info := rateLimiter.GetRateLimitInfo("192.168.1.1") +fmt.Printf("IP: %s\n", info.IP) +fmt.Printf("Tokens Remaining: %.2f\n", info.TokensRemaining) +fmt.Printf("Limit: %.2f req/sec\n", info.Limit) +fmt.Printf("Burst: %d\n", info.Burst) + +// Get info for all tracked IPs +allInfo := rateLimiter.GetAllRateLimitInfo() +for _, info := range allInfo { + fmt.Printf("%s: %.2f tokens remaining\n", info.IP, info.TokensRemaining) +} +``` + +### Rate Limit Stats Endpoint + +Expose rate limit statistics via HTTP: + +```go +// Add stats endpoint +http.Handle("/admin/rate-limit-stats", rateLimiter.StatsHandler()) +``` + +**Example Response (all IPs):** +```json +{ + "total_tracked_ips": 3, + "rate_limit_config": { + "requests_per_second": 10, + "burst": 20 + }, + "tracked_ips": [ + { + "ip": "192.168.1.1", + "tokens_remaining": 15.5, + "limit": 10, + "burst": 20 + }, + { + "ip": "192.168.1.2", + "tokens_remaining": 18.2, + "limit": 10, + "burst": 20 + } + ] +} +``` + +**Example Response (specific IP):** +```bash +GET /admin/rate-limit-stats?ip=192.168.1.1 +``` +```json +{ + "ip": "192.168.1.1", + "tokens_remaining": 15.5, + "limit": 10, + "burst": 20 +} +``` + +### Complete Integration Example + +```go +package main + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/bitechdev/ResolveSpec/pkg/middleware" +) + +func main() { + // Create rate limiter + rateLimiter := middleware.NewRateLimiter(10, 20) + + // Create blacklist + blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{ + UseProxy: true, + }) + + mux := http.NewServeMux() + + // API handler with both middlewares (blacklist first, then rate limit) + apiHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "message": "Success", + }) + }) + + // Apply middleware chain: blacklist -> rate limit -> handler + mux.Handle("/api/", blacklist.Middleware(rateLimiter.Middleware(apiHandler))) + + // Admin endpoints + mux.Handle("/admin/rate-limit-stats", rateLimiter.StatsHandler()) + mux.Handle("/admin/blacklist-stats", blacklist.StatsHandler()) + + // Custom monitoring endpoint + mux.HandleFunc("/admin/monitor", func(w http.ResponseWriter, r *http.Request) { + // Get rate limit stats + rateLimitInfo := rateLimiter.GetAllRateLimitInfo() + + // Get blacklist stats + blockedIPs, blockedCIDRs := blacklist.GetBlacklist() + + response := map[string]interface{}{ + "rate_limits": rateLimitInfo, + "blacklist": map[string]interface{}{ + "ips": blockedIPs, + "cidrs": blockedCIDRs, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + }) + + // Dynamic blacklist management + mux.HandleFunc("/admin/block", func(w http.ResponseWriter, r *http.Request) { + ip := r.URL.Query().Get("ip") + reason := r.URL.Query().Get("reason") + + if ip == "" { + http.Error(w, "IP required", http.StatusBadRequest) + return + } + + if err := blacklist.BlockIP(ip, reason); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + fmt.Fprintf(w, "Blocked %s: %s", ip, reason) + }) + + mux.HandleFunc("/admin/unblock", func(w http.ResponseWriter, r *http.Request) { + ip := r.URL.Query().Get("ip") + if ip == "" { + http.Error(w, "IP required", http.StatusBadRequest) + return + } + + blacklist.UnblockIP(ip) + fmt.Fprintf(w, "Unblocked %s", ip) + }) + + // Auto-block IPs that exceed rate limit + mux.HandleFunc("/admin/auto-block-heavy-users", func(w http.ResponseWriter, r *http.Request) { + blocked := 0 + + for _, info := range rateLimiter.GetAllRateLimitInfo() { + // If tokens are very low, IP is making many requests + if info.TokensRemaining < 1.0 { + blacklist.BlockIP(info.IP, "Exceeded rate limit") + blocked++ + } + } + + fmt.Fprintf(w, "Blocked %d IPs exceeding rate limits", blocked) + }) + + fmt.Println("Server starting on :8080") + fmt.Println("Rate limit stats: http://localhost:8080/admin/rate-limit-stats") + fmt.Println("Blacklist stats: http://localhost:8080/admin/blacklist-stats") + http.ListenAndServe(":8080", mux) +} +``` + +--- + +## Monitoring Dashboard Example + +Create a simple monitoring page: + +```go +mux.HandleFunc("/admin/dashboard", func(w http.ResponseWriter, r *http.Request) { + html := ` + +
+Loading...+ +
Loading...+ + + ` + + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(html)) +}) +``` + +--- + +## Best Practices + +### 1. Proxy Configuration +Always set `UseProxy: true` when running behind a reverse proxy (nginx, Cloudflare, etc.): +```go +blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{ + UseProxy: true, // Checks X-Forwarded-For headers +}) +``` + +### 2. Middleware Order +Apply blacklist before rate limiting to save resources: +```go +// Correct order: blacklist -> rate limit -> handler +handler := blacklist.Middleware( + rateLimiter.Middleware(yourHandler) +) +``` + +### 3. Secure Admin Endpoints +Protect admin endpoints with authentication: +```go +mux.Handle("/admin/", authMiddleware(adminHandler)) +``` + +### 4. Monitoring +Set up alerts when: +- Many IPs are being rate limited +- Blacklist grows too large +- Specific IPs are repeatedly blocked + +### 5. Dynamic Response +Automatically block IPs that consistently exceed rate limits: +```go +// Check every minute +ticker := time.NewTicker(1 * time.Minute) +go func() { + for range ticker.C { + for _, info := range rateLimiter.GetAllRateLimitInfo() { + if info.TokensRemaining < 0.5 { + blacklist.BlockIP(info.IP, "Automated block: rate limit exceeded") + } + } + } +}() +``` + +### 6. CIDR for Network Blocks +Use CIDR ranges to block entire networks efficiently: +```go +// Block entire subnets +blacklist.BlockCIDR("10.0.0.0/8", "Private network") +blacklist.BlockCIDR("192.168.0.0/16", "Local network") +``` + +--- + +## API Reference + +### IPBlacklist + +#### Methods +- `BlockIP(ip, reason string) error` - Block a single IP address +- `BlockCIDR(cidr, reason string) error` - Block a CIDR range +- `UnblockIP(ip string)` - Remove IP from blacklist +- `UnblockCIDR(cidr string)` - Remove CIDR from blacklist +- `IsBlocked(ip string) (blocked bool, reason string)` - Check if IP is blocked +- `GetBlacklist() (ips, cidrs []string)` - Get all blocked IPs and CIDRs +- `Middleware(next http.Handler) http.Handler` - HTTP middleware +- `StatsHandler() http.Handler` - HTTP handler for statistics + +### RateLimiter + +#### Methods +- `GetTrackedIPs() []string` - Get all tracked IP addresses +- `GetRateLimitInfo(ip string) *RateLimitInfo` - Get info for specific IP +- `GetAllRateLimitInfo() []*RateLimitInfo` - Get info for all tracked IPs +- `Middleware(next http.Handler) http.Handler` - HTTP middleware +- `StatsHandler() http.Handler` - HTTP handler for statistics + +#### RateLimitInfo Structure +```go +type RateLimitInfo struct { + IP string `json:"ip"` + TokensRemaining float64 `json:"tokens_remaining"` + Limit float64 `json:"limit"` + Burst int `json:"burst"` +} +``` diff --git a/pkg/middleware/blacklist.go b/pkg/middleware/blacklist.go new file mode 100644 index 0000000..274f4b7 --- /dev/null +++ b/pkg/middleware/blacklist.go @@ -0,0 +1,204 @@ +package middleware + +import ( + "encoding/json" + "net" + "net/http" + "strings" + "sync" +) + +// IPBlacklist provides IP blocking functionality +type IPBlacklist struct { + mu sync.RWMutex + ips map[string]bool // Individual IPs + cidrs []*net.IPNet // CIDR ranges + reason map[string]string + useProxy bool // Whether to check X-Forwarded-For headers +} + +// BlacklistConfig configures the IP blacklist +type BlacklistConfig struct { + // UseProxy indicates whether to extract IP from X-Forwarded-For/X-Real-IP headers + UseProxy bool +} + +// NewIPBlacklist creates a new IP blacklist +func NewIPBlacklist(config BlacklistConfig) *IPBlacklist { + return &IPBlacklist{ + ips: make(map[string]bool), + cidrs: make([]*net.IPNet, 0), + reason: make(map[string]string), + useProxy: config.UseProxy, + } +} + +// BlockIP blocks a single IP address +func (bl *IPBlacklist) BlockIP(ip string, reason string) error { + // Validate IP + if net.ParseIP(ip) == nil { + return &net.ParseError{Type: "IP address", Text: ip} + } + + bl.mu.Lock() + defer bl.mu.Unlock() + + bl.ips[ip] = true + if reason != "" { + bl.reason[ip] = reason + } + return nil +} + +// BlockCIDR blocks an IP range using CIDR notation +func (bl *IPBlacklist) BlockCIDR(cidr string, reason string) error { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + + bl.mu.Lock() + defer bl.mu.Unlock() + + bl.cidrs = append(bl.cidrs, ipNet) + if reason != "" { + bl.reason[cidr] = reason + } + return nil +} + +// UnblockIP removes an IP from the blacklist +func (bl *IPBlacklist) UnblockIP(ip string) { + bl.mu.Lock() + defer bl.mu.Unlock() + + delete(bl.ips, ip) + delete(bl.reason, ip) +} + +// UnblockCIDR removes a CIDR range from the blacklist +func (bl *IPBlacklist) UnblockCIDR(cidr string) { + bl.mu.Lock() + defer bl.mu.Unlock() + + // Find and remove the CIDR + for i, ipNet := range bl.cidrs { + if ipNet.String() == cidr { + bl.cidrs = append(bl.cidrs[:i], bl.cidrs[i+1:]...) + break + } + } + delete(bl.reason, cidr) +} + +// IsBlocked checks if an IP is blacklisted +func (bl *IPBlacklist) IsBlocked(ip string) (bool, string) { + bl.mu.RLock() + defer bl.mu.RUnlock() + + // Check individual IPs + if bl.ips[ip] { + return true, bl.reason[ip] + } + + // Check CIDR ranges + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false, "" + } + + for i, ipNet := range bl.cidrs { + if ipNet.Contains(parsedIP) { + cidr := ipNet.String() + // Try to find reason by CIDR or by index + if reason, ok := bl.reason[cidr]; ok { + return true, reason + } + // Check if reason was stored by original CIDR string + for key, reason := range bl.reason { + if strings.Contains(key, "/") && key == cidr { + return true, reason + } + } + // Return true even if no reason found + if i < len(bl.cidrs) { + return true, "" + } + } + } + + return false, "" +} + +// GetBlacklist returns all blacklisted IPs and CIDRs +func (bl *IPBlacklist) GetBlacklist() (ips []string, cidrs []string) { + bl.mu.RLock() + defer bl.mu.RUnlock() + + ips = make([]string, 0, len(bl.ips)) + for ip := range bl.ips { + ips = append(ips, ip) + } + + cidrs = make([]string, 0, len(bl.cidrs)) + for _, ipNet := range bl.cidrs { + cidrs = append(cidrs, ipNet.String()) + } + + return ips, cidrs +} + +// Middleware returns an HTTP middleware that blocks blacklisted IPs +func (bl *IPBlacklist) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var clientIP string + if bl.useProxy { + clientIP = getClientIP(r) + // Clean up IPv6 brackets if present + clientIP = strings.Trim(clientIP, "[]") + } else { + // Extract IP from RemoteAddr + if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 { + clientIP = r.RemoteAddr[:idx] + } else { + clientIP = r.RemoteAddr + } + clientIP = strings.Trim(clientIP, "[]") + } + + blocked, reason := bl.IsBlocked(clientIP) + if blocked { + response := map[string]interface{}{ + "error": "forbidden", + "message": "Access denied", + } + if reason != "" { + response["reason"] = reason + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(response) + return + } + + next.ServeHTTP(w, r) + }) +} + +// StatsHandler returns an HTTP handler that shows blacklist statistics +func (bl *IPBlacklist) StatsHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ips, cidrs := bl.GetBlacklist() + + stats := map[string]interface{}{ + "blocked_ips": ips, + "blocked_cidrs": cidrs, + "total_ips": len(ips), + "total_cidrs": len(cidrs), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(stats) + }) +} diff --git a/pkg/middleware/blacklist_test.go b/pkg/middleware/blacklist_test.go new file mode 100644 index 0000000..eead24e --- /dev/null +++ b/pkg/middleware/blacklist_test.go @@ -0,0 +1,254 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestIPBlacklist_BlockIP(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + + // Block an IP + err := bl.BlockIP("192.168.1.100", "Suspicious activity") + if err != nil { + t.Fatalf("BlockIP() error = %v", err) + } + + // Check if IP is blocked + blocked, reason := bl.IsBlocked("192.168.1.100") + if !blocked { + t.Error("IP should be blocked") + } + if reason != "Suspicious activity" { + t.Errorf("Reason = %q, want %q", reason, "Suspicious activity") + } + + // Check non-blocked IP + blocked, _ = bl.IsBlocked("192.168.1.1") + if blocked { + t.Error("IP should not be blocked") + } +} + +func TestIPBlacklist_BlockCIDR(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + + // Block a CIDR range + err := bl.BlockCIDR("10.0.0.0/24", "Internal network blocked") + if err != nil { + t.Fatalf("BlockCIDR() error = %v", err) + } + + // Check IPs in range + testIPs := []string{ + "10.0.0.1", + "10.0.0.100", + "10.0.0.254", + } + + for _, ip := range testIPs { + blocked, _ := bl.IsBlocked(ip) + if !blocked { + t.Errorf("IP %s should be blocked by CIDR", ip) + } + } + + // Check IP outside range + blocked, _ := bl.IsBlocked("10.0.1.1") + if blocked { + t.Error("IP outside CIDR range should not be blocked") + } +} + +func TestIPBlacklist_UnblockIP(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + + // Block and then unblock + bl.BlockIP("192.168.1.100", "Test") + + blocked, _ := bl.IsBlocked("192.168.1.100") + if !blocked { + t.Error("IP should be blocked") + } + + bl.UnblockIP("192.168.1.100") + + blocked, _ = bl.IsBlocked("192.168.1.100") + if blocked { + t.Error("IP should be unblocked") + } +} + +func TestIPBlacklist_UnblockCIDR(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + + // Block and then unblock CIDR + bl.BlockCIDR("10.0.0.0/24", "Test") + + blocked, _ := bl.IsBlocked("10.0.0.1") + if !blocked { + t.Error("IP should be blocked by CIDR") + } + + bl.UnblockCIDR("10.0.0.0/24") + + blocked, _ = bl.IsBlocked("10.0.0.1") + if blocked { + t.Error("IP should be unblocked after CIDR removal") + } +} + +func TestIPBlacklist_Middleware(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + bl.BlockIP("192.168.1.100", "Banned") + + handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + + // Blocked IP should get 403 + t.Run("BlockedIP", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if response["error"] != "forbidden" { + t.Errorf("Error = %v, want %q", response["error"], "forbidden") + } + }) + + // Allowed IP should succeed + t.Run("AllowedIP", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", w.Code, http.StatusOK) + } + }) +} + +func TestIPBlacklist_MiddlewareWithProxy(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: true}) + bl.BlockIP("203.0.113.1", "Blocked via proxy") + + handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Test X-Forwarded-For + t.Run("X-Forwarded-For", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:12345" + req.Header.Set("X-Forwarded-For", "203.0.113.1") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden) + } + }) + + // Test X-Real-IP + t.Run("X-Real-IP", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:12345" + req.Header.Set("X-Real-IP", "203.0.113.1") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestIPBlacklist_StatsHandler(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + bl.BlockIP("192.168.1.100", "Test1") + bl.BlockIP("192.168.1.101", "Test2") + bl.BlockCIDR("10.0.0.0/24", "Test CIDR") + + handler := bl.StatsHandler() + + req := httptest.NewRequest("GET", "/blacklist-stats", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", w.Code, http.StatusOK) + } + + var stats map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if int(stats["total_ips"].(float64)) != 2 { + t.Errorf("total_ips = %v, want 2", stats["total_ips"]) + } + + if int(stats["total_cidrs"].(float64)) != 1 { + t.Errorf("total_cidrs = %v, want 1", stats["total_cidrs"]) + } +} + +func TestIPBlacklist_GetBlacklist(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + bl.BlockIP("192.168.1.100", "") + bl.BlockIP("192.168.1.101", "") + bl.BlockCIDR("10.0.0.0/24", "") + + ips, cidrs := bl.GetBlacklist() + + if len(ips) != 2 { + t.Errorf("len(ips) = %d, want 2", len(ips)) + } + + if len(cidrs) != 1 { + t.Errorf("len(cidrs) = %d, want 1", len(cidrs)) + } + + // Verify CIDR format + if cidrs[0] != "10.0.0.0/24" { + t.Errorf("CIDR = %q, want %q", cidrs[0], "10.0.0.0/24") + } +} + +func TestIPBlacklist_InvalidIP(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + + err := bl.BlockIP("invalid-ip", "Test") + if err == nil { + t.Error("BlockIP() should return error for invalid IP") + } +} + +func TestIPBlacklist_InvalidCIDR(t *testing.T) { + bl := NewIPBlacklist(BlacklistConfig{UseProxy: false}) + + err := bl.BlockCIDR("invalid-cidr", "Test") + if err == nil { + t.Error("BlockCIDR() should return error for invalid CIDR") + } +} diff --git a/pkg/middleware/ratelimit.go b/pkg/middleware/ratelimit.go index debc2f2..ac10314 100644 --- a/pkg/middleware/ratelimit.go +++ b/pkg/middleware/ratelimit.go @@ -1,7 +1,9 @@ package middleware import ( + "encoding/json" "net/http" + "strings" "sync" "time" @@ -71,11 +73,11 @@ func (rl *RateLimiter) cleanupRoutine() { } // Middleware returns an HTTP middleware that applies rate limiting +// Automatically handles X-Forwarded-For headers when behind a proxy func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Use IP address as the rate limit key - // In production, you might want to use X-Forwarded-For or custom headers - key := r.RemoteAddr + // Extract client IP, handling proxy headers + key := getClientIP(r) limiter := rl.getLimiter(key) @@ -108,3 +110,115 @@ func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string) }) } } + +// RateLimitInfo contains information about a specific IP's rate limit status +type RateLimitInfo struct { + IP string `json:"ip"` + TokensRemaining float64 `json:"tokens_remaining"` + Limit float64 `json:"limit"` + Burst int `json:"burst"` +} + +// GetTrackedIPs returns all IPs currently being tracked by the rate limiter +func (rl *RateLimiter) GetTrackedIPs() []string { + rl.mu.RLock() + defer rl.mu.RUnlock() + + ips := make([]string, 0, len(rl.limiters)) + for ip := range rl.limiters { + ips = append(ips, ip) + } + return ips +} + +// GetRateLimitInfo returns rate limit information for a specific IP +func (rl *RateLimiter) GetRateLimitInfo(ip string) *RateLimitInfo { + rl.mu.RLock() + limiter, exists := rl.limiters[ip] + rl.mu.RUnlock() + + if !exists { + // Return default info for untracked IP + return &RateLimitInfo{ + IP: ip, + TokensRemaining: float64(rl.burst), + Limit: float64(rl.rate), + Burst: rl.burst, + } + } + + return &RateLimitInfo{ + IP: ip, + TokensRemaining: limiter.Tokens(), + Limit: float64(rl.rate), + Burst: rl.burst, + } +} + +// GetAllRateLimitInfo returns rate limit information for all tracked IPs +func (rl *RateLimiter) GetAllRateLimitInfo() []*RateLimitInfo { + ips := rl.GetTrackedIPs() + info := make([]*RateLimitInfo, 0, len(ips)) + + for _, ip := range ips { + info = append(info, rl.GetRateLimitInfo(ip)) + } + + return info +} + +// StatsHandler returns an HTTP handler that exposes rate limit statistics +// Example: GET /rate-limit-stats +func (rl *RateLimiter) StatsHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Support querying specific IP via ?ip=x.x.x.x + if ip := r.URL.Query().Get("ip"); ip != "" { + info := rl.GetRateLimitInfo(ip) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(info) + return + } + + // Return all tracked IPs + allInfo := rl.GetAllRateLimitInfo() + + stats := map[string]interface{}{ + "total_tracked_ips": len(allInfo), + "rate_limit_config": map[string]interface{}{ + "requests_per_second": float64(rl.rate), + "burst": rl.burst, + }, + "tracked_ips": allInfo, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(stats) + }) +} + +// getClientIP extracts the real client IP from the request +// Handles X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header (most common in production) + // Format: X-Forwarded-For: client, proxy1, proxy2 + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP (the original client) + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + + // Check X-Real-IP header (used by some proxies like nginx) + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } + + // Fall back to RemoteAddr + // Remove port if present (format: "ip:port") + if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 { + return r.RemoteAddr[:idx] + } + + return r.RemoteAddr +} diff --git a/pkg/middleware/ratelimit_test.go b/pkg/middleware/ratelimit_test.go new file mode 100644 index 0000000..2e58f2d --- /dev/null +++ b/pkg/middleware/ratelimit_test.go @@ -0,0 +1,388 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRateLimiter(t *testing.T) { + // Create rate limiter: 2 requests per second, burst of 2 + rl := NewRateLimiter(2, 2) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + + // First request should succeed + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("First request failed: got %d, want %d", w.Code, http.StatusOK) + } + + // Second request should succeed (within burst) + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Second request failed: got %d, want %d", w.Code, http.StatusOK) + } + + // Third request should be rate limited + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("Third request should be rate limited: got %d, want %d", w.Code, http.StatusTooManyRequests) + } + + // Wait for rate limiter to refill + time.Sleep(600 * time.Millisecond) + + // Request should succeed again + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Request after wait failed: got %d, want %d", w.Code, http.StatusOK) + } +} + +func TestRateLimiterDifferentIPs(t *testing.T) { + rl := NewRateLimiter(1, 1) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First IP + req1 := httptest.NewRequest("GET", "/test", nil) + req1.RemoteAddr = "192.168.1.1:12345" + + // Second IP + req2 := httptest.NewRequest("GET", "/test", nil) + req2.RemoteAddr = "192.168.1.2:12345" + + // Both should succeed (different IPs) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + if w1.Code != http.StatusOK { + t.Errorf("First IP request failed: got %d, want %d", w1.Code, http.StatusOK) + } + + if w2.Code != http.StatusOK { + t.Errorf("Second IP request failed: got %d, want %d", w2.Code, http.StatusOK) + } +} + +func TestGetClientIP(t *testing.T) { + tests := []struct { + name string + remoteAddr string + xForwardedFor string + xRealIP string + expectedIP string + }{ + { + name: "RemoteAddr only", + remoteAddr: "192.168.1.1:12345", + expectedIP: "192.168.1.1", + }, + { + name: "X-Forwarded-For single IP", + remoteAddr: "10.0.0.1:12345", + xForwardedFor: "203.0.113.1", + expectedIP: "203.0.113.1", + }, + { + name: "X-Forwarded-For multiple IPs", + remoteAddr: "10.0.0.1:12345", + xForwardedFor: "203.0.113.1, 10.0.0.2, 10.0.0.3", + expectedIP: "203.0.113.1", + }, + { + name: "X-Real-IP", + remoteAddr: "10.0.0.1:12345", + xRealIP: "203.0.113.1", + expectedIP: "203.0.113.1", + }, + { + name: "X-Forwarded-For takes precedence over X-Real-IP", + remoteAddr: "10.0.0.1:12345", + xForwardedFor: "203.0.113.1", + xRealIP: "203.0.113.2", + expectedIP: "203.0.113.1", + }, + { + name: "IPv6 address", + remoteAddr: "[2001:db8::1]:12345", + expectedIP: "[2001:db8::1]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = tt.remoteAddr + + if tt.xForwardedFor != "" { + req.Header.Set("X-Forwarded-For", tt.xForwardedFor) + } + if tt.xRealIP != "" { + req.Header.Set("X-Real-IP", tt.xRealIP) + } + + ip := getClientIP(req) + if ip != tt.expectedIP { + t.Errorf("getClientIP() = %q, want %q", ip, tt.expectedIP) + } + }) + } +} + +func TestRateLimiterWithCustomKeyFunc(t *testing.T) { + rl := NewRateLimiter(1, 1) + + // Use user ID as key + keyFunc := func(r *http.Request) string { + userID := r.Header.Get("X-User-ID") + if userID == "" { + return r.RemoteAddr + } + return "user:" + userID + } + + handler := rl.MiddlewareWithKeyFunc(keyFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // User 1 + req1 := httptest.NewRequest("GET", "/test", nil) + req1.Header.Set("X-User-ID", "user1") + + // User 2 + req2 := httptest.NewRequest("GET", "/test", nil) + req2.Header.Set("X-User-ID", "user2") + + // Both users should succeed (different keys) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + if w1.Code != http.StatusOK { + t.Errorf("User 1 request failed: got %d, want %d", w1.Code, http.StatusOK) + } + + if w2.Code != http.StatusOK { + t.Errorf("User 2 request failed: got %d, want %d", w2.Code, http.StatusOK) + } + + // User 1 second request should be rate limited + w1 = httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + if w1.Code != http.StatusTooManyRequests { + t.Errorf("User 1 second request should be rate limited: got %d, want %d", w1.Code, http.StatusTooManyRequests) + } +} + +func TestRateLimiter_GetTrackedIPs(t *testing.T) { + rl := NewRateLimiter(10, 10) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make requests from different IPs + ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"} + for _, ip := range ips { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = ip + ":12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + } + + // Check tracked IPs + trackedIPs := rl.GetTrackedIPs() + if len(trackedIPs) != len(ips) { + t.Errorf("len(trackedIPs) = %d, want %d", len(trackedIPs), len(ips)) + } + + // Verify all IPs are tracked + ipMap := make(map[string]bool) + for _, ip := range trackedIPs { + ipMap[ip] = true + } + + for _, ip := range ips { + if !ipMap[ip] { + t.Errorf("IP %s should be tracked", ip) + } + } +} + +func TestRateLimiter_GetRateLimitInfo(t *testing.T) { + rl := NewRateLimiter(10, 5) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make a request + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + // Get rate limit info + info := rl.GetRateLimitInfo("192.168.1.1") + + if info.IP != "192.168.1.1" { + t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1") + } + + if info.Limit != 10.0 { + t.Errorf("Limit = %f, want 10.0", info.Limit) + } + + if info.Burst != 5 { + t.Errorf("Burst = %d, want 5", info.Burst) + } + + // Tokens should be less than burst after one request + if info.TokensRemaining >= float64(info.Burst) { + t.Errorf("TokensRemaining = %f, should be less than %d", info.TokensRemaining, info.Burst) + } +} + +func TestRateLimiter_GetRateLimitInfo_UntrackedIP(t *testing.T) { + rl := NewRateLimiter(10, 5) + + // Get info for untracked IP (should return default) + info := rl.GetRateLimitInfo("192.168.1.1") + + if info.IP != "192.168.1.1" { + t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1") + } + + if info.TokensRemaining != float64(rl.burst) { + t.Errorf("TokensRemaining = %f, want %d (full burst)", info.TokensRemaining, rl.burst) + } +} + +func TestRateLimiter_GetAllRateLimitInfo(t *testing.T) { + rl := NewRateLimiter(10, 10) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make requests from different IPs + ips := []string{"192.168.1.1", "192.168.1.2"} + for _, ip := range ips { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = ip + ":12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + } + + // Get all rate limit info + allInfo := rl.GetAllRateLimitInfo() + + if len(allInfo) != len(ips) { + t.Errorf("len(allInfo) = %d, want %d", len(allInfo), len(ips)) + } + + // Verify each IP has info + for _, info := range allInfo { + found := false + for _, ip := range ips { + if info.IP == ip { + found = true + break + } + } + if !found { + t.Errorf("Unexpected IP in info: %s", info.IP) + } + } +} + +func TestRateLimiter_StatsHandler(t *testing.T) { + rl := NewRateLimiter(10, 5) + + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make requests from different IPs + req1 := httptest.NewRequest("GET", "/test", nil) + req1.RemoteAddr = "192.168.1.1:12345" + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + req2 := httptest.NewRequest("GET", "/test", nil) + req2.RemoteAddr = "192.168.1.2:12345" + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + // Test stats handler (all IPs) + t.Run("AllIPs", func(t *testing.T) { + statsHandler := rl.StatsHandler() + req := httptest.NewRequest("GET", "/rate-limit-stats", nil) + w := httptest.NewRecorder() + statsHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", w.Code, http.StatusOK) + } + + var stats map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if int(stats["total_tracked_ips"].(float64)) != 2 { + t.Errorf("total_tracked_ips = %v, want 2", stats["total_tracked_ips"]) + } + + config := stats["rate_limit_config"].(map[string]interface{}) + if config["requests_per_second"].(float64) != 10.0 { + t.Errorf("requests_per_second = %v, want 10.0", config["requests_per_second"]) + } + }) + + // Test stats handler (specific IP) + t.Run("SpecificIP", func(t *testing.T) { + statsHandler := rl.StatsHandler() + req := httptest.NewRequest("GET", "/rate-limit-stats?ip=192.168.1.1", nil) + w := httptest.NewRecorder() + statsHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", w.Code, http.StatusOK) + } + + var info RateLimitInfo + if err := json.Unmarshal(w.Body.Bytes(), &info); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if info.IP != "192.168.1.1" { + t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1") + } + }) +} diff --git a/pkg/middleware/sanitize_test.go b/pkg/middleware/sanitize_test.go new file mode 100644 index 0000000..fc082f9 --- /dev/null +++ b/pkg/middleware/sanitize_test.go @@ -0,0 +1,273 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestSanitizeXSS(t *testing.T) { + sanitizer := DefaultSanitizer() + + tests := []struct { + name string + input string + contains string // String that should NOT be in output + }{ + { + name: "Script tag", + input: "", + contains: "John", + "email": "test@example.com", + "nested": map[string]interface{}{ + "bio": "", + }, + } + + result := sanitizer.SanitizeMap(input) + + // Check that script tag was removed/escaped + name, ok := result["name"].(string) + if !ok || name == input["name"] { + t.Error("Name should be sanitized") + } + + // Check nested map + nested, ok := result["nested"].(map[string]interface{}) + if !ok { + t.Fatal("Nested should still be a map") + } + + bio, ok := nested["bio"].(string) + if !ok || bio == input["nested"].(map[string]interface{})["bio"] { + t.Error("Nested bio should be sanitized") + } +} + +func TestSanitizeMiddleware(t *testing.T) { + sanitizer := DefaultSanitizer() + + handler := sanitizer.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that query param was sanitized + param := r.URL.Query().Get("q") + if param == "" { + t.Error("Query param should be sanitized") + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test?q=", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Handler failed: got %d, want %d", w.Code, http.StatusOK) + } +} + +func TestSanitizeFilename(t *testing.T) { + tests := []struct { + name string + input string + contains string // String that should NOT be in output + }{ + { + name: "Path traversal", + input: "../../../etc/passwd", + contains: "..", + }, + { + name: "Absolute path", + input: "/etc/passwd", + contains: "/", + }, + { + name: "Windows path", + input: "..\\..\\windows\\system32", + contains: "\\", + }, + { + name: "Null byte", + input: "file\x00.txt", + contains: "\x00", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeFilename(tt.input) + if result == tt.input { + t.Errorf("SanitizeFilename() did not modify input: %q", tt.input) + } + }) + } +} + +func TestSanitizeEmail(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "Uppercase", + input: "TEST@EXAMPLE.COM", + expected: "test@example.com", + }, + { + name: "Whitespace", + input: " test@example.com ", + expected: "test@example.com", + }, + { + name: "Null bytes", + input: "test\x00@example.com", + expected: "test@example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeEmail(tt.input) + if result != tt.expected { + t.Errorf("SanitizeEmail() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestSanitizeURL(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "JavaScript protocol", + input: "javascript:alert(1)", + expected: "", + }, + { + name: "Data protocol", + input: "data:text/html,", + expected: "", + }, + { + name: "Valid HTTP URL", + input: "https://example.com", + expected: "https://example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeURL(tt.input) + if result != tt.expected { + t.Errorf("SanitizeURL() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestStrictSanitizer(t *testing.T) { + sanitizer := StrictSanitizer() + + input := "Bold text with " + result := sanitizer.Sanitize(input) + + // Should strip ALL HTML tags + if result == input { + t.Error("Strict sanitizer should modify input") + } + + // Should not contain any HTML tags + if len(result) > 0 && (result[0] == '<' || result[len(result)-1] == '>') { + t.Error("Result should not contain HTML tags") + } +} + +func TestMaxStringLength(t *testing.T) { + sanitizer := &Sanitizer{ + MaxStringLength: 10, + } + + input := "This is a very long string that exceeds the maximum length" + result := sanitizer.Sanitize(input) + + if len(result) != 10 { + t.Errorf("Result length = %d, want 10", len(result)) + } + + if result != input[:10] { + t.Errorf("Result = %q, want %q", result, input[:10]) + } +} diff --git a/pkg/middleware/sizelimit_test.go b/pkg/middleware/sizelimit_test.go new file mode 100644 index 0000000..1d56415 --- /dev/null +++ b/pkg/middleware/sizelimit_test.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestSizeLimiter(t *testing.T) { + // 1KB limit + limiter := NewRequestSizeLimiter(1024) + + handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Try to read body + _, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusRequestEntityTooLarge) + return + } + w.WriteHeader(http.StatusOK) + })) + + // Small request (should succeed) + t.Run("SmallRequest", func(t *testing.T) { + body := bytes.NewReader(make([]byte, 512)) // 512 bytes + req := httptest.NewRequest("POST", "/test", body) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Small request failed: got %d, want %d", w.Code, http.StatusOK) + } + + // Check header + if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "1024" { + t.Errorf("MaxRequestSizeHeader = %q, want %q", maxSize, "1024") + } + }) + + // Large request (should fail) + t.Run("LargeRequest", func(t *testing.T) { + body := bytes.NewReader(make([]byte, 2048)) // 2KB + req := httptest.NewRequest("POST", "/test", body) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("Large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge) + } + }) +} + +func TestRequestSizeLimiterDefault(t *testing.T) { + // Default limiter (10MB) + limiter := NewRequestSizeLimiter(0) + + handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/test", bytes.NewReader(make([]byte, 1024))) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Request failed: got %d, want %d", w.Code, http.StatusOK) + } + + // Check default size + if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "10485760" { + t.Errorf("Default MaxRequestSizeHeader = %q, want %q", maxSize, "10485760") + } +} + +func TestRequestSizeLimiterWithCustomSize(t *testing.T) { + limiter := NewRequestSizeLimiter(1024) + + // Premium users get 10MB, regular users get 1KB + sizeFunc := func(r *http.Request) int64 { + if r.Header.Get("X-User-Tier") == "premium" { + return Size10MB + } + return 1024 + } + + handler := limiter.MiddlewareWithCustomSize(sizeFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusRequestEntityTooLarge) + return + } + w.WriteHeader(http.StatusOK) + })) + + // Regular user with large request (should fail) + t.Run("RegularUserLargeRequest", func(t *testing.T) { + body := bytes.NewReader(make([]byte, 2048)) + req := httptest.NewRequest("POST", "/test", body) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("Regular user large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge) + } + }) + + // Premium user with large request (should succeed) + t.Run("PremiumUserLargeRequest", func(t *testing.T) { + body := bytes.NewReader(make([]byte, 2048)) + req := httptest.NewRequest("POST", "/test", body) + req.Header.Set("X-User-Tier", "premium") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Premium user large request failed: got %d, want %d", w.Code, http.StatusOK) + } + }) +} diff --git a/pkg/server/shutdown_test.go b/pkg/server/shutdown_test.go new file mode 100644 index 0000000..8eef439 --- /dev/null +++ b/pkg/server/shutdown_test.go @@ -0,0 +1,231 @@ +package server + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +func TestGracefulServerTrackRequests(t *testing.T) { + srv := NewGracefulServer(Config{ + Addr: ":0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }), + }) + + handler := srv.TrackRequestsMiddleware(srv.server.Handler) + + // Start some requests + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + }() + } + + // Wait a bit for requests to start + time.Sleep(10 * time.Millisecond) + + // Check in-flight count + inFlight := srv.InFlightRequests() + if inFlight == 0 { + t.Error("Should have in-flight requests") + } + + // Wait for all requests to complete + wg.Wait() + + // Check that counter is back to zero + inFlight = srv.InFlightRequests() + if inFlight != 0 { + t.Errorf("In-flight requests should be 0, got %d", inFlight) + } +} + +func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) { + srv := NewGracefulServer(Config{ + Addr: ":0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), + }) + + handler := srv.TrackRequestsMiddleware(srv.server.Handler) + + // Mark as shutting down + srv.isShuttingDown.Store(true) + + // Try to make a request + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + // Should get 503 + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Expected 503, got %d", w.Code) + } +} + +func TestHealthCheckHandler(t *testing.T) { + srv := NewGracefulServer(Config{ + Addr: ":0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + }) + + handler := srv.HealthCheckHandler() + + // Healthy + t.Run("Healthy", func(t *testing.T) { + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + + if w.Body.String() != `{"status":"healthy"}` { + t.Errorf("Unexpected body: %s", w.Body.String()) + } + }) + + // Shutting down + t.Run("ShuttingDown", func(t *testing.T) { + srv.isShuttingDown.Store(true) + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Expected 503, got %d", w.Code) + } + }) +} + +func TestReadinessHandler(t *testing.T) { + srv := NewGracefulServer(Config{ + Addr: ":0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + }) + + handler := srv.ReadinessHandler() + + // Ready with no in-flight requests + t.Run("Ready", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ready", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d", w.Code) + } + + body := w.Body.String() + if body != `{"ready":true,"in_flight_requests":0}` { + t.Errorf("Unexpected body: %s", body) + } + }) + + // Not ready during shutdown + t.Run("NotReady", func(t *testing.T) { + srv.isShuttingDown.Store(true) + + req := httptest.NewRequest("GET", "/ready", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Expected 503, got %d", w.Code) + } + }) +} + +func TestShutdownCallbacks(t *testing.T) { + callbackExecuted := false + + RegisterShutdownCallback(func(ctx context.Context) error { + callbackExecuted = true + return nil + }) + + ctx := context.Background() + err := executeShutdownCallbacks(ctx) + + if err != nil { + t.Errorf("executeShutdownCallbacks() error = %v", err) + } + + if !callbackExecuted { + t.Error("Shutdown callback was not executed") + } + + // Reset for other tests + shutdownCallbacks = nil +} + +func TestDrainRequests(t *testing.T) { + srv := NewGracefulServer(Config{ + Addr: ":0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + DrainTimeout: 1 * time.Second, + }) + + // Simulate in-flight requests + srv.inFlightRequests.Add(3) + + // Start draining in background + go func() { + time.Sleep(100 * time.Millisecond) + // Simulate requests completing + srv.inFlightRequests.Add(-3) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := srv.drainRequests(ctx) + if err != nil { + t.Errorf("drainRequests() error = %v", err) + } + + if srv.InFlightRequests() != 0 { + t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests()) + } +} + +func TestDrainRequestsTimeout(t *testing.T) { + srv := NewGracefulServer(Config{ + Addr: ":0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + DrainTimeout: 100 * time.Millisecond, + }) + + // Simulate in-flight requests that don't complete + srv.inFlightRequests.Add(5) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + err := srv.drainRequests(ctx) + if err == nil { + t.Error("drainRequests() should timeout with error") + } + + // Cleanup + srv.inFlightRequests.Add(-5) +} + +func TestGetClientIP(t *testing.T) { + // This test is in ratelimit_test.go since getClientIP is used by rate limiter + // Including here for completeness of server tests +}