mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
Added blacklist middleware
This commit is contained in:
parent
2a84652dba
commit
aeae9d7e0c
440
SECURITY_FEATURES.md
Normal file
440
SECURITY_FEATURES.md
Normal file
@ -0,0 +1,440 @@
|
||||
# Security Features: Blacklist & Rate Limit Inspection
|
||||
|
||||
## IP Blacklist
|
||||
|
||||
The IP blacklist middleware allows you to block specific IP addresses or CIDR ranges from accessing your application.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
|
||||
// Create blacklist (UseProxy=true if behind a proxy)
|
||||
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
|
||||
UseProxy: true, // Checks X-Forwarded-For and X-Real-IP headers
|
||||
})
|
||||
|
||||
// Block individual IP
|
||||
blacklist.BlockIP("192.168.1.100", "Suspicious activity detected")
|
||||
|
||||
// Block entire CIDR range
|
||||
blacklist.BlockCIDR("10.0.0.0/8", "Private network blocked")
|
||||
|
||||
// Apply middleware
|
||||
http.Handle("/api/", blacklist.Middleware(yourHandler))
|
||||
```
|
||||
|
||||
### Managing Blacklist
|
||||
|
||||
```go
|
||||
// Unblock an IP
|
||||
blacklist.UnblockIP("192.168.1.100")
|
||||
|
||||
// Unblock a CIDR range
|
||||
blacklist.UnblockCIDR("10.0.0.0/8")
|
||||
|
||||
// Get all blacklisted IPs and CIDRs
|
||||
ips, cidrs := blacklist.GetBlacklist()
|
||||
fmt.Printf("Blocked IPs: %v\n", ips)
|
||||
fmt.Printf("Blocked CIDRs: %v\n", cidrs)
|
||||
|
||||
// Check if specific IP is blocked
|
||||
blocked, reason := blacklist.IsBlocked("192.168.1.100")
|
||||
if blocked {
|
||||
fmt.Printf("IP blocked: %s\n", reason)
|
||||
}
|
||||
```
|
||||
|
||||
### Blacklist Statistics Endpoint
|
||||
|
||||
Expose blacklist statistics via HTTP:
|
||||
|
||||
```go
|
||||
// Add stats endpoint
|
||||
http.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
|
||||
```
|
||||
|
||||
**Example Response:**
|
||||
```json
|
||||
{
|
||||
"blocked_ips": ["192.168.1.100", "192.168.1.101"],
|
||||
"blocked_cidrs": ["10.0.0.0/8"],
|
||||
"total_ips": 2,
|
||||
"total_cidrs": 1
|
||||
}
|
||||
```
|
||||
|
||||
### Integration Example
|
||||
|
||||
```go
|
||||
func main() {
|
||||
// Create blacklist
|
||||
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
|
||||
UseProxy: true,
|
||||
})
|
||||
|
||||
// Block known malicious IPs
|
||||
blacklist.BlockIP("203.0.113.1", "Known scanner")
|
||||
blacklist.BlockCIDR("198.51.100.0/24", "Spam network")
|
||||
|
||||
// Create your router
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Protected routes
|
||||
mux.Handle("/api/", blacklist.Middleware(apiHandler))
|
||||
|
||||
// Admin endpoint to manage blacklist
|
||||
mux.HandleFunc("/admin/block-ip", func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := r.URL.Query().Get("ip")
|
||||
reason := r.URL.Query().Get("reason")
|
||||
|
||||
if err := blacklist.BlockIP(ip, reason); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "Blocked %s: %s", ip, reason)
|
||||
})
|
||||
|
||||
// Stats endpoint
|
||||
mux.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
|
||||
|
||||
http.ListenAndServe(":8080", mux)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Rate Limit Inspection
|
||||
|
||||
Monitor and inspect rate limit status per IP address in real-time.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
|
||||
// Create rate limiter (10 req/sec, burst of 20)
|
||||
rateLimiter := middleware.NewRateLimiter(10, 20)
|
||||
|
||||
// Apply middleware
|
||||
http.Handle("/api/", rateLimiter.Middleware(yourHandler))
|
||||
```
|
||||
|
||||
### Programmatic Inspection
|
||||
|
||||
```go
|
||||
// Get all tracked IPs
|
||||
trackedIPs := rateLimiter.GetTrackedIPs()
|
||||
fmt.Printf("Currently tracking %d IPs\n", len(trackedIPs))
|
||||
|
||||
// Get rate limit info for specific IP
|
||||
info := rateLimiter.GetRateLimitInfo("192.168.1.1")
|
||||
fmt.Printf("IP: %s\n", info.IP)
|
||||
fmt.Printf("Tokens Remaining: %.2f\n", info.TokensRemaining)
|
||||
fmt.Printf("Limit: %.2f req/sec\n", info.Limit)
|
||||
fmt.Printf("Burst: %d\n", info.Burst)
|
||||
|
||||
// Get info for all tracked IPs
|
||||
allInfo := rateLimiter.GetAllRateLimitInfo()
|
||||
for _, info := range allInfo {
|
||||
fmt.Printf("%s: %.2f tokens remaining\n", info.IP, info.TokensRemaining)
|
||||
}
|
||||
```
|
||||
|
||||
### Rate Limit Stats Endpoint
|
||||
|
||||
Expose rate limit statistics via HTTP:
|
||||
|
||||
```go
|
||||
// Add stats endpoint
|
||||
http.Handle("/admin/rate-limit-stats", rateLimiter.StatsHandler())
|
||||
```
|
||||
|
||||
**Example Response (all IPs):**
|
||||
```json
|
||||
{
|
||||
"total_tracked_ips": 3,
|
||||
"rate_limit_config": {
|
||||
"requests_per_second": 10,
|
||||
"burst": 20
|
||||
},
|
||||
"tracked_ips": [
|
||||
{
|
||||
"ip": "192.168.1.1",
|
||||
"tokens_remaining": 15.5,
|
||||
"limit": 10,
|
||||
"burst": 20
|
||||
},
|
||||
{
|
||||
"ip": "192.168.1.2",
|
||||
"tokens_remaining": 18.2,
|
||||
"limit": 10,
|
||||
"burst": 20
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Example Response (specific IP):**
|
||||
```bash
|
||||
GET /admin/rate-limit-stats?ip=192.168.1.1
|
||||
```
|
||||
```json
|
||||
{
|
||||
"ip": "192.168.1.1",
|
||||
"tokens_remaining": 15.5,
|
||||
"limit": 10,
|
||||
"burst": 20
|
||||
}
|
||||
```
|
||||
|
||||
### Complete Integration Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create rate limiter
|
||||
rateLimiter := middleware.NewRateLimiter(10, 20)
|
||||
|
||||
// Create blacklist
|
||||
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
|
||||
UseProxy: true,
|
||||
})
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// API handler with both middlewares (blacklist first, then rate limit)
|
||||
apiHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "Success",
|
||||
})
|
||||
})
|
||||
|
||||
// Apply middleware chain: blacklist -> rate limit -> handler
|
||||
mux.Handle("/api/", blacklist.Middleware(rateLimiter.Middleware(apiHandler)))
|
||||
|
||||
// Admin endpoints
|
||||
mux.Handle("/admin/rate-limit-stats", rateLimiter.StatsHandler())
|
||||
mux.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
|
||||
|
||||
// Custom monitoring endpoint
|
||||
mux.HandleFunc("/admin/monitor", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get rate limit stats
|
||||
rateLimitInfo := rateLimiter.GetAllRateLimitInfo()
|
||||
|
||||
// Get blacklist stats
|
||||
blockedIPs, blockedCIDRs := blacklist.GetBlacklist()
|
||||
|
||||
response := map[string]interface{}{
|
||||
"rate_limits": rateLimitInfo,
|
||||
"blacklist": map[string]interface{}{
|
||||
"ips": blockedIPs,
|
||||
"cidrs": blockedCIDRs,
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
})
|
||||
|
||||
// Dynamic blacklist management
|
||||
mux.HandleFunc("/admin/block", func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := r.URL.Query().Get("ip")
|
||||
reason := r.URL.Query().Get("reason")
|
||||
|
||||
if ip == "" {
|
||||
http.Error(w, "IP required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := blacklist.BlockIP(ip, reason); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "Blocked %s: %s", ip, reason)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/admin/unblock", func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := r.URL.Query().Get("ip")
|
||||
if ip == "" {
|
||||
http.Error(w, "IP required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
blacklist.UnblockIP(ip)
|
||||
fmt.Fprintf(w, "Unblocked %s", ip)
|
||||
})
|
||||
|
||||
// Auto-block IPs that exceed rate limit
|
||||
mux.HandleFunc("/admin/auto-block-heavy-users", func(w http.ResponseWriter, r *http.Request) {
|
||||
blocked := 0
|
||||
|
||||
for _, info := range rateLimiter.GetAllRateLimitInfo() {
|
||||
// If tokens are very low, IP is making many requests
|
||||
if info.TokensRemaining < 1.0 {
|
||||
blacklist.BlockIP(info.IP, "Exceeded rate limit")
|
||||
blocked++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "Blocked %d IPs exceeding rate limits", blocked)
|
||||
})
|
||||
|
||||
fmt.Println("Server starting on :8080")
|
||||
fmt.Println("Rate limit stats: http://localhost:8080/admin/rate-limit-stats")
|
||||
fmt.Println("Blacklist stats: http://localhost:8080/admin/blacklist-stats")
|
||||
http.ListenAndServe(":8080", mux)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Monitoring Dashboard Example
|
||||
|
||||
Create a simple monitoring page:
|
||||
|
||||
```go
|
||||
mux.HandleFunc("/admin/dashboard", func(w http.ResponseWriter, r *http.Request) {
|
||||
html := `
|
||||
<html>
|
||||
<head>
|
||||
<title>Security Dashboard</title>
|
||||
<script>
|
||||
async function loadStats() {
|
||||
const rateLimitRes = await fetch('/admin/rate-limit-stats');
|
||||
const rateLimitData = await rateLimitRes.json();
|
||||
|
||||
const blacklistRes = await fetch('/admin/blacklist-stats');
|
||||
const blacklistData = await blacklistRes.json();
|
||||
|
||||
document.getElementById('rate-limit').innerHTML =
|
||||
JSON.stringify(rateLimitData, null, 2);
|
||||
document.getElementById('blacklist').innerHTML =
|
||||
JSON.stringify(blacklistData, null, 2);
|
||||
}
|
||||
|
||||
setInterval(loadStats, 5000); // Refresh every 5 seconds
|
||||
loadStats();
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Security Dashboard</h1>
|
||||
|
||||
<h2>Rate Limits</h2>
|
||||
<pre id="rate-limit">Loading...</pre>
|
||||
|
||||
<h2>Blacklist</h2>
|
||||
<pre id="blacklist">Loading...</pre>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte(html))
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Proxy Configuration
|
||||
Always set `UseProxy: true` when running behind a reverse proxy (nginx, Cloudflare, etc.):
|
||||
```go
|
||||
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
|
||||
UseProxy: true, // Checks X-Forwarded-For headers
|
||||
})
|
||||
```
|
||||
|
||||
### 2. Middleware Order
|
||||
Apply blacklist before rate limiting to save resources:
|
||||
```go
|
||||
// Correct order: blacklist -> rate limit -> handler
|
||||
handler := blacklist.Middleware(
|
||||
rateLimiter.Middleware(yourHandler)
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Secure Admin Endpoints
|
||||
Protect admin endpoints with authentication:
|
||||
```go
|
||||
mux.Handle("/admin/", authMiddleware(adminHandler))
|
||||
```
|
||||
|
||||
### 4. Monitoring
|
||||
Set up alerts when:
|
||||
- Many IPs are being rate limited
|
||||
- Blacklist grows too large
|
||||
- Specific IPs are repeatedly blocked
|
||||
|
||||
### 5. Dynamic Response
|
||||
Automatically block IPs that consistently exceed rate limits:
|
||||
```go
|
||||
// Check every minute
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
for _, info := range rateLimiter.GetAllRateLimitInfo() {
|
||||
if info.TokensRemaining < 0.5 {
|
||||
blacklist.BlockIP(info.IP, "Automated block: rate limit exceeded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
```
|
||||
|
||||
### 6. CIDR for Network Blocks
|
||||
Use CIDR ranges to block entire networks efficiently:
|
||||
```go
|
||||
// Block entire subnets
|
||||
blacklist.BlockCIDR("10.0.0.0/8", "Private network")
|
||||
blacklist.BlockCIDR("192.168.0.0/16", "Local network")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API Reference
|
||||
|
||||
### IPBlacklist
|
||||
|
||||
#### Methods
|
||||
- `BlockIP(ip, reason string) error` - Block a single IP address
|
||||
- `BlockCIDR(cidr, reason string) error` - Block a CIDR range
|
||||
- `UnblockIP(ip string)` - Remove IP from blacklist
|
||||
- `UnblockCIDR(cidr string)` - Remove CIDR from blacklist
|
||||
- `IsBlocked(ip string) (blocked bool, reason string)` - Check if IP is blocked
|
||||
- `GetBlacklist() (ips, cidrs []string)` - Get all blocked IPs and CIDRs
|
||||
- `Middleware(next http.Handler) http.Handler` - HTTP middleware
|
||||
- `StatsHandler() http.Handler` - HTTP handler for statistics
|
||||
|
||||
### RateLimiter
|
||||
|
||||
#### Methods
|
||||
- `GetTrackedIPs() []string` - Get all tracked IP addresses
|
||||
- `GetRateLimitInfo(ip string) *RateLimitInfo` - Get info for specific IP
|
||||
- `GetAllRateLimitInfo() []*RateLimitInfo` - Get info for all tracked IPs
|
||||
- `Middleware(next http.Handler) http.Handler` - HTTP middleware
|
||||
- `StatsHandler() http.Handler` - HTTP handler for statistics
|
||||
|
||||
#### RateLimitInfo Structure
|
||||
```go
|
||||
type RateLimitInfo struct {
|
||||
IP string `json:"ip"`
|
||||
TokensRemaining float64 `json:"tokens_remaining"`
|
||||
Limit float64 `json:"limit"`
|
||||
Burst int `json:"burst"`
|
||||
}
|
||||
```
|
||||
204
pkg/middleware/blacklist.go
Normal file
204
pkg/middleware/blacklist.go
Normal file
@ -0,0 +1,204 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// IPBlacklist provides IP blocking functionality
|
||||
type IPBlacklist struct {
|
||||
mu sync.RWMutex
|
||||
ips map[string]bool // Individual IPs
|
||||
cidrs []*net.IPNet // CIDR ranges
|
||||
reason map[string]string
|
||||
useProxy bool // Whether to check X-Forwarded-For headers
|
||||
}
|
||||
|
||||
// BlacklistConfig configures the IP blacklist
|
||||
type BlacklistConfig struct {
|
||||
// UseProxy indicates whether to extract IP from X-Forwarded-For/X-Real-IP headers
|
||||
UseProxy bool
|
||||
}
|
||||
|
||||
// NewIPBlacklist creates a new IP blacklist
|
||||
func NewIPBlacklist(config BlacklistConfig) *IPBlacklist {
|
||||
return &IPBlacklist{
|
||||
ips: make(map[string]bool),
|
||||
cidrs: make([]*net.IPNet, 0),
|
||||
reason: make(map[string]string),
|
||||
useProxy: config.UseProxy,
|
||||
}
|
||||
}
|
||||
|
||||
// BlockIP blocks a single IP address
|
||||
func (bl *IPBlacklist) BlockIP(ip string, reason string) error {
|
||||
// Validate IP
|
||||
if net.ParseIP(ip) == nil {
|
||||
return &net.ParseError{Type: "IP address", Text: ip}
|
||||
}
|
||||
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
bl.ips[ip] = true
|
||||
if reason != "" {
|
||||
bl.reason[ip] = reason
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BlockCIDR blocks an IP range using CIDR notation
|
||||
func (bl *IPBlacklist) BlockCIDR(cidr string, reason string) error {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
bl.cidrs = append(bl.cidrs, ipNet)
|
||||
if reason != "" {
|
||||
bl.reason[cidr] = reason
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnblockIP removes an IP from the blacklist
|
||||
func (bl *IPBlacklist) UnblockIP(ip string) {
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
delete(bl.ips, ip)
|
||||
delete(bl.reason, ip)
|
||||
}
|
||||
|
||||
// UnblockCIDR removes a CIDR range from the blacklist
|
||||
func (bl *IPBlacklist) UnblockCIDR(cidr string) {
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
// Find and remove the CIDR
|
||||
for i, ipNet := range bl.cidrs {
|
||||
if ipNet.String() == cidr {
|
||||
bl.cidrs = append(bl.cidrs[:i], bl.cidrs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
delete(bl.reason, cidr)
|
||||
}
|
||||
|
||||
// IsBlocked checks if an IP is blacklisted
|
||||
func (bl *IPBlacklist) IsBlocked(ip string) (bool, string) {
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
// Check individual IPs
|
||||
if bl.ips[ip] {
|
||||
return true, bl.reason[ip]
|
||||
}
|
||||
|
||||
// Check CIDR ranges
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
for i, ipNet := range bl.cidrs {
|
||||
if ipNet.Contains(parsedIP) {
|
||||
cidr := ipNet.String()
|
||||
// Try to find reason by CIDR or by index
|
||||
if reason, ok := bl.reason[cidr]; ok {
|
||||
return true, reason
|
||||
}
|
||||
// Check if reason was stored by original CIDR string
|
||||
for key, reason := range bl.reason {
|
||||
if strings.Contains(key, "/") && key == cidr {
|
||||
return true, reason
|
||||
}
|
||||
}
|
||||
// Return true even if no reason found
|
||||
if i < len(bl.cidrs) {
|
||||
return true, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// GetBlacklist returns all blacklisted IPs and CIDRs
|
||||
func (bl *IPBlacklist) GetBlacklist() (ips []string, cidrs []string) {
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
ips = make([]string, 0, len(bl.ips))
|
||||
for ip := range bl.ips {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
|
||||
cidrs = make([]string, 0, len(bl.cidrs))
|
||||
for _, ipNet := range bl.cidrs {
|
||||
cidrs = append(cidrs, ipNet.String())
|
||||
}
|
||||
|
||||
return ips, cidrs
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that blocks blacklisted IPs
|
||||
func (bl *IPBlacklist) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var clientIP string
|
||||
if bl.useProxy {
|
||||
clientIP = getClientIP(r)
|
||||
// Clean up IPv6 brackets if present
|
||||
clientIP = strings.Trim(clientIP, "[]")
|
||||
} else {
|
||||
// Extract IP from RemoteAddr
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
clientIP = r.RemoteAddr[:idx]
|
||||
} else {
|
||||
clientIP = r.RemoteAddr
|
||||
}
|
||||
clientIP = strings.Trim(clientIP, "[]")
|
||||
}
|
||||
|
||||
blocked, reason := bl.IsBlocked(clientIP)
|
||||
if blocked {
|
||||
response := map[string]interface{}{
|
||||
"error": "forbidden",
|
||||
"message": "Access denied",
|
||||
}
|
||||
if reason != "" {
|
||||
response["reason"] = reason
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// StatsHandler returns an HTTP handler that shows blacklist statistics
|
||||
func (bl *IPBlacklist) StatsHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ips, cidrs := bl.GetBlacklist()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"blocked_ips": ips,
|
||||
"blocked_cidrs": cidrs,
|
||||
"total_ips": len(ips),
|
||||
"total_cidrs": len(cidrs),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(stats)
|
||||
})
|
||||
}
|
||||
254
pkg/middleware/blacklist_test.go
Normal file
254
pkg/middleware/blacklist_test.go
Normal file
@ -0,0 +1,254 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPBlacklist_BlockIP(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
// Block an IP
|
||||
err := bl.BlockIP("192.168.1.100", "Suspicious activity")
|
||||
if err != nil {
|
||||
t.Fatalf("BlockIP() error = %v", err)
|
||||
}
|
||||
|
||||
// Check if IP is blocked
|
||||
blocked, reason := bl.IsBlocked("192.168.1.100")
|
||||
if !blocked {
|
||||
t.Error("IP should be blocked")
|
||||
}
|
||||
if reason != "Suspicious activity" {
|
||||
t.Errorf("Reason = %q, want %q", reason, "Suspicious activity")
|
||||
}
|
||||
|
||||
// Check non-blocked IP
|
||||
blocked, _ = bl.IsBlocked("192.168.1.1")
|
||||
if blocked {
|
||||
t.Error("IP should not be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_BlockCIDR(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
// Block a CIDR range
|
||||
err := bl.BlockCIDR("10.0.0.0/24", "Internal network blocked")
|
||||
if err != nil {
|
||||
t.Fatalf("BlockCIDR() error = %v", err)
|
||||
}
|
||||
|
||||
// Check IPs in range
|
||||
testIPs := []string{
|
||||
"10.0.0.1",
|
||||
"10.0.0.100",
|
||||
"10.0.0.254",
|
||||
}
|
||||
|
||||
for _, ip := range testIPs {
|
||||
blocked, _ := bl.IsBlocked(ip)
|
||||
if !blocked {
|
||||
t.Errorf("IP %s should be blocked by CIDR", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Check IP outside range
|
||||
blocked, _ := bl.IsBlocked("10.0.1.1")
|
||||
if blocked {
|
||||
t.Error("IP outside CIDR range should not be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_UnblockIP(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
// Block and then unblock
|
||||
bl.BlockIP("192.168.1.100", "Test")
|
||||
|
||||
blocked, _ := bl.IsBlocked("192.168.1.100")
|
||||
if !blocked {
|
||||
t.Error("IP should be blocked")
|
||||
}
|
||||
|
||||
bl.UnblockIP("192.168.1.100")
|
||||
|
||||
blocked, _ = bl.IsBlocked("192.168.1.100")
|
||||
if blocked {
|
||||
t.Error("IP should be unblocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_UnblockCIDR(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
// Block and then unblock CIDR
|
||||
bl.BlockCIDR("10.0.0.0/24", "Test")
|
||||
|
||||
blocked, _ := bl.IsBlocked("10.0.0.1")
|
||||
if !blocked {
|
||||
t.Error("IP should be blocked by CIDR")
|
||||
}
|
||||
|
||||
bl.UnblockCIDR("10.0.0.0/24")
|
||||
|
||||
blocked, _ = bl.IsBlocked("10.0.0.1")
|
||||
if blocked {
|
||||
t.Error("IP should be unblocked after CIDR removal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_Middleware(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
bl.BlockIP("192.168.1.100", "Banned")
|
||||
|
||||
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
|
||||
// Blocked IP should get 403
|
||||
t.Run("BlockedIP", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "forbidden" {
|
||||
t.Errorf("Error = %v, want %q", response["error"], "forbidden")
|
||||
}
|
||||
})
|
||||
|
||||
// Allowed IP should succeed
|
||||
t.Run("AllowedIP", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIPBlacklist_MiddlewareWithProxy(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: true})
|
||||
bl.BlockIP("203.0.113.1", "Blocked via proxy")
|
||||
|
||||
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test X-Forwarded-For
|
||||
t.Run("X-Forwarded-For", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
req.Header.Set("X-Forwarded-For", "203.0.113.1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
|
||||
// Test X-Real-IP
|
||||
t.Run("X-Real-IP", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIPBlacklist_StatsHandler(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
bl.BlockIP("192.168.1.100", "Test1")
|
||||
bl.BlockIP("192.168.1.101", "Test2")
|
||||
bl.BlockCIDR("10.0.0.0/24", "Test CIDR")
|
||||
|
||||
handler := bl.StatsHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/blacklist-stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var stats map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if int(stats["total_ips"].(float64)) != 2 {
|
||||
t.Errorf("total_ips = %v, want 2", stats["total_ips"])
|
||||
}
|
||||
|
||||
if int(stats["total_cidrs"].(float64)) != 1 {
|
||||
t.Errorf("total_cidrs = %v, want 1", stats["total_cidrs"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_GetBlacklist(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
bl.BlockIP("192.168.1.100", "")
|
||||
bl.BlockIP("192.168.1.101", "")
|
||||
bl.BlockCIDR("10.0.0.0/24", "")
|
||||
|
||||
ips, cidrs := bl.GetBlacklist()
|
||||
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("len(ips) = %d, want 2", len(ips))
|
||||
}
|
||||
|
||||
if len(cidrs) != 1 {
|
||||
t.Errorf("len(cidrs) = %d, want 1", len(cidrs))
|
||||
}
|
||||
|
||||
// Verify CIDR format
|
||||
if cidrs[0] != "10.0.0.0/24" {
|
||||
t.Errorf("CIDR = %q, want %q", cidrs[0], "10.0.0.0/24")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_InvalidIP(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
err := bl.BlockIP("invalid-ip", "Test")
|
||||
if err == nil {
|
||||
t.Error("BlockIP() should return error for invalid IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_InvalidCIDR(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
err := bl.BlockCIDR("invalid-cidr", "Test")
|
||||
if err == nil {
|
||||
t.Error("BlockCIDR() should return error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
@ -1,7 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -71,11 +73,11 @@ func (rl *RateLimiter) cleanupRoutine() {
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that applies rate limiting
|
||||
// Automatically handles X-Forwarded-For headers when behind a proxy
|
||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Use IP address as the rate limit key
|
||||
// In production, you might want to use X-Forwarded-For or custom headers
|
||||
key := r.RemoteAddr
|
||||
// Extract client IP, handling proxy headers
|
||||
key := getClientIP(r)
|
||||
|
||||
limiter := rl.getLimiter(key)
|
||||
|
||||
@ -108,3 +110,115 @@ func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitInfo contains information about a specific IP's rate limit status
|
||||
type RateLimitInfo struct {
|
||||
IP string `json:"ip"`
|
||||
TokensRemaining float64 `json:"tokens_remaining"`
|
||||
Limit float64 `json:"limit"`
|
||||
Burst int `json:"burst"`
|
||||
}
|
||||
|
||||
// GetTrackedIPs returns all IPs currently being tracked by the rate limiter
|
||||
func (rl *RateLimiter) GetTrackedIPs() []string {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
ips := make([]string, 0, len(rl.limiters))
|
||||
for ip := range rl.limiters {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// GetRateLimitInfo returns rate limit information for a specific IP
|
||||
func (rl *RateLimiter) GetRateLimitInfo(ip string) *RateLimitInfo {
|
||||
rl.mu.RLock()
|
||||
limiter, exists := rl.limiters[ip]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Return default info for untracked IP
|
||||
return &RateLimitInfo{
|
||||
IP: ip,
|
||||
TokensRemaining: float64(rl.burst),
|
||||
Limit: float64(rl.rate),
|
||||
Burst: rl.burst,
|
||||
}
|
||||
}
|
||||
|
||||
return &RateLimitInfo{
|
||||
IP: ip,
|
||||
TokensRemaining: limiter.Tokens(),
|
||||
Limit: float64(rl.rate),
|
||||
Burst: rl.burst,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllRateLimitInfo returns rate limit information for all tracked IPs
|
||||
func (rl *RateLimiter) GetAllRateLimitInfo() []*RateLimitInfo {
|
||||
ips := rl.GetTrackedIPs()
|
||||
info := make([]*RateLimitInfo, 0, len(ips))
|
||||
|
||||
for _, ip := range ips {
|
||||
info = append(info, rl.GetRateLimitInfo(ip))
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// StatsHandler returns an HTTP handler that exposes rate limit statistics
|
||||
// Example: GET /rate-limit-stats
|
||||
func (rl *RateLimiter) StatsHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Support querying specific IP via ?ip=x.x.x.x
|
||||
if ip := r.URL.Query().Get("ip"); ip != "" {
|
||||
info := rl.GetRateLimitInfo(ip)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(info)
|
||||
return
|
||||
}
|
||||
|
||||
// Return all tracked IPs
|
||||
allInfo := rl.GetAllRateLimitInfo()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_tracked_ips": len(allInfo),
|
||||
"rate_limit_config": map[string]interface{}{
|
||||
"requests_per_second": float64(rl.rate),
|
||||
"burst": rl.burst,
|
||||
},
|
||||
"tracked_ips": allInfo,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(stats)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP extracts the real client IP from the request
|
||||
// Handles X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (most common in production)
|
||||
// Format: X-Forwarded-For: client, proxy1, proxy2
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP (the original client)
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Check X-Real-IP header (used by some proxies like nginx)
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
// Remove port if present (format: "ip:port")
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
388
pkg/middleware/ratelimit_test.go
Normal file
388
pkg/middleware/ratelimit_test.go
Normal file
@ -0,0 +1,388 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
// Create rate limiter: 2 requests per second, burst of 2
|
||||
rl := NewRateLimiter(2, 2)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
|
||||
// First request should succeed
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("First request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Second request should succeed (within burst)
|
||||
w = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Second request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
w = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("Third request should be rate limited: got %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
// Wait for rate limiter to refill
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
// Request should succeed again
|
||||
w = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Request after wait failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterDifferentIPs(t *testing.T) {
|
||||
rl := NewRateLimiter(1, 1)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First IP
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
|
||||
// Second IP
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.2:12345"
|
||||
|
||||
// Both should succeed (different IPs)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Errorf("First IP request failed: got %d, want %d", w1.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("Second IP request failed: got %d, want %d", w2.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
xRealIP string
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "RemoteAddr only",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For single IP",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
xForwardedFor: "203.0.113.1",
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For multiple IPs",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
xForwardedFor: "203.0.113.1, 10.0.0.2, 10.0.0.3",
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
xRealIP: "203.0.113.1",
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
xForwardedFor: "203.0.113.1",
|
||||
xRealIP: "203.0.113.2",
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
remoteAddr: "[2001:db8::1]:12345",
|
||||
expectedIP: "[2001:db8::1]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
if tt.xRealIP != "" {
|
||||
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||
}
|
||||
|
||||
ip := getClientIP(req)
|
||||
if ip != tt.expectedIP {
|
||||
t.Errorf("getClientIP() = %q, want %q", ip, tt.expectedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterWithCustomKeyFunc(t *testing.T) {
|
||||
rl := NewRateLimiter(1, 1)
|
||||
|
||||
// Use user ID as key
|
||||
keyFunc := func(r *http.Request) string {
|
||||
userID := r.Header.Get("X-User-ID")
|
||||
if userID == "" {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return "user:" + userID
|
||||
}
|
||||
|
||||
handler := rl.MiddlewareWithKeyFunc(keyFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// User 1
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.Header.Set("X-User-ID", "user1")
|
||||
|
||||
// User 2
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.Header.Set("X-User-ID", "user2")
|
||||
|
||||
// Both users should succeed (different keys)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Errorf("User 1 request failed: got %d, want %d", w1.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("User 2 request failed: got %d, want %d", w2.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// User 1 second request should be rate limited
|
||||
w1 = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
if w1.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("User 1 second request should be rate limited: got %d, want %d", w1.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_GetTrackedIPs(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 10)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make requests from different IPs
|
||||
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
||||
for _, ip := range ips {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = ip + ":12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Check tracked IPs
|
||||
trackedIPs := rl.GetTrackedIPs()
|
||||
if len(trackedIPs) != len(ips) {
|
||||
t.Errorf("len(trackedIPs) = %d, want %d", len(trackedIPs), len(ips))
|
||||
}
|
||||
|
||||
// Verify all IPs are tracked
|
||||
ipMap := make(map[string]bool)
|
||||
for _, ip := range trackedIPs {
|
||||
ipMap[ip] = true
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if !ipMap[ip] {
|
||||
t.Errorf("IP %s should be tracked", ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_GetRateLimitInfo(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 5)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make a request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// Get rate limit info
|
||||
info := rl.GetRateLimitInfo("192.168.1.1")
|
||||
|
||||
if info.IP != "192.168.1.1" {
|
||||
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||
}
|
||||
|
||||
if info.Limit != 10.0 {
|
||||
t.Errorf("Limit = %f, want 10.0", info.Limit)
|
||||
}
|
||||
|
||||
if info.Burst != 5 {
|
||||
t.Errorf("Burst = %d, want 5", info.Burst)
|
||||
}
|
||||
|
||||
// Tokens should be less than burst after one request
|
||||
if info.TokensRemaining >= float64(info.Burst) {
|
||||
t.Errorf("TokensRemaining = %f, should be less than %d", info.TokensRemaining, info.Burst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_GetRateLimitInfo_UntrackedIP(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 5)
|
||||
|
||||
// Get info for untracked IP (should return default)
|
||||
info := rl.GetRateLimitInfo("192.168.1.1")
|
||||
|
||||
if info.IP != "192.168.1.1" {
|
||||
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||
}
|
||||
|
||||
if info.TokensRemaining != float64(rl.burst) {
|
||||
t.Errorf("TokensRemaining = %f, want %d (full burst)", info.TokensRemaining, rl.burst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_GetAllRateLimitInfo(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 10)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make requests from different IPs
|
||||
ips := []string{"192.168.1.1", "192.168.1.2"}
|
||||
for _, ip := range ips {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = ip + ":12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Get all rate limit info
|
||||
allInfo := rl.GetAllRateLimitInfo()
|
||||
|
||||
if len(allInfo) != len(ips) {
|
||||
t.Errorf("len(allInfo) = %d, want %d", len(allInfo), len(ips))
|
||||
}
|
||||
|
||||
// Verify each IP has info
|
||||
for _, info := range allInfo {
|
||||
found := false
|
||||
for _, ip := range ips {
|
||||
if info.IP == ip {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Unexpected IP in info: %s", info.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_StatsHandler(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 5)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make requests from different IPs
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.2:12345"
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
// Test stats handler (all IPs)
|
||||
t.Run("AllIPs", func(t *testing.T) {
|
||||
statsHandler := rl.StatsHandler()
|
||||
req := httptest.NewRequest("GET", "/rate-limit-stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
statsHandler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var stats map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if int(stats["total_tracked_ips"].(float64)) != 2 {
|
||||
t.Errorf("total_tracked_ips = %v, want 2", stats["total_tracked_ips"])
|
||||
}
|
||||
|
||||
config := stats["rate_limit_config"].(map[string]interface{})
|
||||
if config["requests_per_second"].(float64) != 10.0 {
|
||||
t.Errorf("requests_per_second = %v, want 10.0", config["requests_per_second"])
|
||||
}
|
||||
})
|
||||
|
||||
// Test stats handler (specific IP)
|
||||
t.Run("SpecificIP", func(t *testing.T) {
|
||||
statsHandler := rl.StatsHandler()
|
||||
req := httptest.NewRequest("GET", "/rate-limit-stats?ip=192.168.1.1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
statsHandler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var info RateLimitInfo
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &info); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if info.IP != "192.168.1.1" {
|
||||
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||
}
|
||||
})
|
||||
}
|
||||
273
pkg/middleware/sanitize_test.go
Normal file
273
pkg/middleware/sanitize_test.go
Normal file
@ -0,0 +1,273 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeXSS(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
contains string // String that should NOT be in output
|
||||
}{
|
||||
{
|
||||
name: "Script tag",
|
||||
input: "<script>alert(1)</script>",
|
||||
contains: "<script>",
|
||||
},
|
||||
{
|
||||
name: "JavaScript protocol",
|
||||
input: "javascript:alert(1)",
|
||||
contains: "javascript:",
|
||||
},
|
||||
{
|
||||
name: "Event handler",
|
||||
input: "<img onerror='alert(1)'>",
|
||||
contains: "onerror=",
|
||||
},
|
||||
{
|
||||
name: "Iframe",
|
||||
input: "<iframe src='evil.com'></iframe>",
|
||||
contains: "<iframe",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizer.Sanitize(tt.input)
|
||||
if result == tt.input {
|
||||
t.Errorf("Sanitize() did not modify input: %q", tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeNullBytes(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
input := "hello\x00world"
|
||||
result := sanitizer.Sanitize(input)
|
||||
|
||||
if result == input {
|
||||
t.Error("Null bytes should be removed")
|
||||
}
|
||||
|
||||
if len(result) >= len(input) {
|
||||
t.Errorf("Result length should be less than input: got %d, input %d", len(result), len(input))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeControlCharacters(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
// Include various control characters
|
||||
input := "hello\x01\x02world\x1F"
|
||||
result := sanitizer.Sanitize(input)
|
||||
|
||||
if result == input {
|
||||
t.Error("Control characters should be removed")
|
||||
}
|
||||
|
||||
// Newlines, tabs, carriage returns should be preserved
|
||||
input2 := "hello\nworld\t\r"
|
||||
result2 := sanitizer.Sanitize(input2)
|
||||
|
||||
if result2 != input2 {
|
||||
t.Errorf("Safe control characters should be preserved: got %q, want %q", result2, input2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeMap(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"name": "<script>alert(1)</script>John",
|
||||
"email": "test@example.com",
|
||||
"nested": map[string]interface{}{
|
||||
"bio": "<iframe src='evil.com'>Bio</iframe>",
|
||||
},
|
||||
}
|
||||
|
||||
result := sanitizer.SanitizeMap(input)
|
||||
|
||||
// Check that script tag was removed/escaped
|
||||
name, ok := result["name"].(string)
|
||||
if !ok || name == input["name"] {
|
||||
t.Error("Name should be sanitized")
|
||||
}
|
||||
|
||||
// Check nested map
|
||||
nested, ok := result["nested"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Nested should still be a map")
|
||||
}
|
||||
|
||||
bio, ok := nested["bio"].(string)
|
||||
if !ok || bio == input["nested"].(map[string]interface{})["bio"] {
|
||||
t.Error("Nested bio should be sanitized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeMiddleware(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
handler := sanitizer.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check that query param was sanitized
|
||||
param := r.URL.Query().Get("q")
|
||||
if param == "<script>alert(1)</script>" {
|
||||
t.Error("Query param should be sanitized")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test?q=<script>alert(1)</script>", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Handler failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
contains string // String that should NOT be in output
|
||||
}{
|
||||
{
|
||||
name: "Path traversal",
|
||||
input: "../../../etc/passwd",
|
||||
contains: "..",
|
||||
},
|
||||
{
|
||||
name: "Absolute path",
|
||||
input: "/etc/passwd",
|
||||
contains: "/",
|
||||
},
|
||||
{
|
||||
name: "Windows path",
|
||||
input: "..\\..\\windows\\system32",
|
||||
contains: "\\",
|
||||
},
|
||||
{
|
||||
name: "Null byte",
|
||||
input: "file\x00.txt",
|
||||
contains: "\x00",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeFilename(tt.input)
|
||||
if result == tt.input {
|
||||
t.Errorf("SanitizeFilename() did not modify input: %q", tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Uppercase",
|
||||
input: "TEST@EXAMPLE.COM",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "Whitespace",
|
||||
input: " test@example.com ",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "Null bytes",
|
||||
input: "test\x00@example.com",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeEmail(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeEmail() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "JavaScript protocol",
|
||||
input: "javascript:alert(1)",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Data protocol",
|
||||
input: "data:text/html,<script>alert(1)</script>",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL",
|
||||
input: "https://example.com",
|
||||
expected: "https://example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeURL(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeURL() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrictSanitizer(t *testing.T) {
|
||||
sanitizer := StrictSanitizer()
|
||||
|
||||
input := "<b>Bold text</b> with <script>alert(1)</script>"
|
||||
result := sanitizer.Sanitize(input)
|
||||
|
||||
// Should strip ALL HTML tags
|
||||
if result == input {
|
||||
t.Error("Strict sanitizer should modify input")
|
||||
}
|
||||
|
||||
// Should not contain any HTML tags
|
||||
if len(result) > 0 && (result[0] == '<' || result[len(result)-1] == '>') {
|
||||
t.Error("Result should not contain HTML tags")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxStringLength(t *testing.T) {
|
||||
sanitizer := &Sanitizer{
|
||||
MaxStringLength: 10,
|
||||
}
|
||||
|
||||
input := "This is a very long string that exceeds the maximum length"
|
||||
result := sanitizer.Sanitize(input)
|
||||
|
||||
if len(result) != 10 {
|
||||
t.Errorf("Result length = %d, want 10", len(result))
|
||||
}
|
||||
|
||||
if result != input[:10] {
|
||||
t.Errorf("Result = %q, want %q", result, input[:10])
|
||||
}
|
||||
}
|
||||
126
pkg/middleware/sizelimit_test.go
Normal file
126
pkg/middleware/sizelimit_test.go
Normal file
@ -0,0 +1,126 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestSizeLimiter(t *testing.T) {
|
||||
// 1KB limit
|
||||
limiter := NewRequestSizeLimiter(1024)
|
||||
|
||||
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Try to read body
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Small request (should succeed)
|
||||
t.Run("SmallRequest", func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 512)) // 512 bytes
|
||||
req := httptest.NewRequest("POST", "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Small request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Check header
|
||||
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "1024" {
|
||||
t.Errorf("MaxRequestSizeHeader = %q, want %q", maxSize, "1024")
|
||||
}
|
||||
})
|
||||
|
||||
// Large request (should fail)
|
||||
t.Run("LargeRequest", func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 2048)) // 2KB
|
||||
req := httptest.NewRequest("POST", "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("Large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestSizeLimiterDefault(t *testing.T) {
|
||||
// Default limiter (10MB)
|
||||
limiter := NewRequestSizeLimiter(0)
|
||||
|
||||
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewReader(make([]byte, 1024)))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Check default size
|
||||
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "10485760" {
|
||||
t.Errorf("Default MaxRequestSizeHeader = %q, want %q", maxSize, "10485760")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestSizeLimiterWithCustomSize(t *testing.T) {
|
||||
limiter := NewRequestSizeLimiter(1024)
|
||||
|
||||
// Premium users get 10MB, regular users get 1KB
|
||||
sizeFunc := func(r *http.Request) int64 {
|
||||
if r.Header.Get("X-User-Tier") == "premium" {
|
||||
return Size10MB
|
||||
}
|
||||
return 1024
|
||||
}
|
||||
|
||||
handler := limiter.MiddlewareWithCustomSize(sizeFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Regular user with large request (should fail)
|
||||
t.Run("RegularUserLargeRequest", func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 2048))
|
||||
req := httptest.NewRequest("POST", "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("Regular user large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||
}
|
||||
})
|
||||
|
||||
// Premium user with large request (should succeed)
|
||||
t.Run("PremiumUserLargeRequest", func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 2048))
|
||||
req := httptest.NewRequest("POST", "/test", body)
|
||||
req.Header.Set("X-User-Tier", "premium")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Premium user large request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
231
pkg/server/shutdown_test.go
Normal file
231
pkg/server/shutdown_test.go
Normal file
@ -0,0 +1,231 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGracefulServerTrackRequests(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
})
|
||||
|
||||
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
||||
|
||||
// Start some requests
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait a bit for requests to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Check in-flight count
|
||||
inFlight := srv.InFlightRequests()
|
||||
if inFlight == 0 {
|
||||
t.Error("Should have in-flight requests")
|
||||
}
|
||||
|
||||
// Wait for all requests to complete
|
||||
wg.Wait()
|
||||
|
||||
// Check that counter is back to zero
|
||||
inFlight = srv.InFlightRequests()
|
||||
if inFlight != 0 {
|
||||
t.Errorf("In-flight requests should be 0, got %d", inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
})
|
||||
|
||||
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
||||
|
||||
// Mark as shutting down
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
// Try to make a request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// Should get 503
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckHandler(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
})
|
||||
|
||||
handler := srv.HealthCheckHandler()
|
||||
|
||||
// Healthy
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if w.Body.String() != `{"status":"healthy"}` {
|
||||
t.Errorf("Unexpected body: %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
// Shutting down
|
||||
t.Run("ShuttingDown", func(t *testing.T) {
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadinessHandler(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
})
|
||||
|
||||
handler := srv.ReadinessHandler()
|
||||
|
||||
// Ready with no in-flight requests
|
||||
t.Run("Ready", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if body != `{"ready":true,"in_flight_requests":0}` {
|
||||
t.Errorf("Unexpected body: %s", body)
|
||||
}
|
||||
})
|
||||
|
||||
// Not ready during shutdown
|
||||
t.Run("NotReady", func(t *testing.T) {
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
req := httptest.NewRequest("GET", "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShutdownCallbacks(t *testing.T) {
|
||||
callbackExecuted := false
|
||||
|
||||
RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
callbackExecuted = true
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := executeShutdownCallbacks(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("executeShutdownCallbacks() error = %v", err)
|
||||
}
|
||||
|
||||
if !callbackExecuted {
|
||||
t.Error("Shutdown callback was not executed")
|
||||
}
|
||||
|
||||
// Reset for other tests
|
||||
shutdownCallbacks = nil
|
||||
}
|
||||
|
||||
func TestDrainRequests(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
DrainTimeout: 1 * time.Second,
|
||||
})
|
||||
|
||||
// Simulate in-flight requests
|
||||
srv.inFlightRequests.Add(3)
|
||||
|
||||
// Start draining in background
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Simulate requests completing
|
||||
srv.inFlightRequests.Add(-3)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := srv.drainRequests(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("drainRequests() error = %v", err)
|
||||
}
|
||||
|
||||
if srv.InFlightRequests() != 0 {
|
||||
t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrainRequestsTimeout(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
DrainTimeout: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Simulate in-flight requests that don't complete
|
||||
srv.inFlightRequests.Add(5)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := srv.drainRequests(ctx)
|
||||
if err == nil {
|
||||
t.Error("drainRequests() should timeout with error")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
srv.inFlightRequests.Add(-5)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
// This test is in ratelimit_test.go since getClientIP is used by rate limiter
|
||||
// Including here for completeness of server tests
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user