mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
Added blacklist middleware
This commit is contained in:
parent
2a84652dba
commit
aeae9d7e0c
440
SECURITY_FEATURES.md
Normal file
440
SECURITY_FEATURES.md
Normal file
@ -0,0 +1,440 @@
|
|||||||
|
# Security Features: Blacklist & Rate Limit Inspection
|
||||||
|
|
||||||
|
## IP Blacklist
|
||||||
|
|
||||||
|
The IP blacklist middleware allows you to block specific IP addresses or CIDR ranges from accessing your application.
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
|
||||||
|
// Create blacklist (UseProxy=true if behind a proxy)
|
||||||
|
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
|
||||||
|
UseProxy: true, // Checks X-Forwarded-For and X-Real-IP headers
|
||||||
|
})
|
||||||
|
|
||||||
|
// Block individual IP
|
||||||
|
blacklist.BlockIP("192.168.1.100", "Suspicious activity detected")
|
||||||
|
|
||||||
|
// Block entire CIDR range
|
||||||
|
blacklist.BlockCIDR("10.0.0.0/8", "Private network blocked")
|
||||||
|
|
||||||
|
// Apply middleware
|
||||||
|
http.Handle("/api/", blacklist.Middleware(yourHandler))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Managing Blacklist
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Unblock an IP
|
||||||
|
blacklist.UnblockIP("192.168.1.100")
|
||||||
|
|
||||||
|
// Unblock a CIDR range
|
||||||
|
blacklist.UnblockCIDR("10.0.0.0/8")
|
||||||
|
|
||||||
|
// Get all blacklisted IPs and CIDRs
|
||||||
|
ips, cidrs := blacklist.GetBlacklist()
|
||||||
|
fmt.Printf("Blocked IPs: %v\n", ips)
|
||||||
|
fmt.Printf("Blocked CIDRs: %v\n", cidrs)
|
||||||
|
|
||||||
|
// Check if specific IP is blocked
|
||||||
|
blocked, reason := blacklist.IsBlocked("192.168.1.100")
|
||||||
|
if blocked {
|
||||||
|
fmt.Printf("IP blocked: %s\n", reason)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Blacklist Statistics Endpoint
|
||||||
|
|
||||||
|
Expose blacklist statistics via HTTP:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Add stats endpoint
|
||||||
|
http.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example Response:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"blocked_ips": ["192.168.1.100", "192.168.1.101"],
|
||||||
|
"blocked_cidrs": ["10.0.0.0/8"],
|
||||||
|
"total_ips": 2,
|
||||||
|
"total_cidrs": 1
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Integration Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
// Create blacklist
|
||||||
|
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
|
||||||
|
UseProxy: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Block known malicious IPs
|
||||||
|
blacklist.BlockIP("203.0.113.1", "Known scanner")
|
||||||
|
blacklist.BlockCIDR("198.51.100.0/24", "Spam network")
|
||||||
|
|
||||||
|
// Create your router
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
// Protected routes
|
||||||
|
mux.Handle("/api/", blacklist.Middleware(apiHandler))
|
||||||
|
|
||||||
|
// Admin endpoint to manage blacklist
|
||||||
|
mux.HandleFunc("/admin/block-ip", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip := r.URL.Query().Get("ip")
|
||||||
|
reason := r.URL.Query().Get("reason")
|
||||||
|
|
||||||
|
if err := blacklist.BlockIP(ip, reason); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, "Blocked %s: %s", ip, reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Stats endpoint
|
||||||
|
mux.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
|
||||||
|
|
||||||
|
http.ListenAndServe(":8080", mux)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Rate Limit Inspection
|
||||||
|
|
||||||
|
Monitor and inspect rate limit status per IP address in real-time.
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
|
||||||
|
// Create rate limiter (10 req/sec, burst of 20)
|
||||||
|
rateLimiter := middleware.NewRateLimiter(10, 20)
|
||||||
|
|
||||||
|
// Apply middleware
|
||||||
|
http.Handle("/api/", rateLimiter.Middleware(yourHandler))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Programmatic Inspection
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get all tracked IPs
|
||||||
|
trackedIPs := rateLimiter.GetTrackedIPs()
|
||||||
|
fmt.Printf("Currently tracking %d IPs\n", len(trackedIPs))
|
||||||
|
|
||||||
|
// Get rate limit info for specific IP
|
||||||
|
info := rateLimiter.GetRateLimitInfo("192.168.1.1")
|
||||||
|
fmt.Printf("IP: %s\n", info.IP)
|
||||||
|
fmt.Printf("Tokens Remaining: %.2f\n", info.TokensRemaining)
|
||||||
|
fmt.Printf("Limit: %.2f req/sec\n", info.Limit)
|
||||||
|
fmt.Printf("Burst: %d\n", info.Burst)
|
||||||
|
|
||||||
|
// Get info for all tracked IPs
|
||||||
|
allInfo := rateLimiter.GetAllRateLimitInfo()
|
||||||
|
for _, info := range allInfo {
|
||||||
|
fmt.Printf("%s: %.2f tokens remaining\n", info.IP, info.TokensRemaining)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rate Limit Stats Endpoint
|
||||||
|
|
||||||
|
Expose rate limit statistics via HTTP:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Add stats endpoint
|
||||||
|
http.Handle("/admin/rate-limit-stats", rateLimiter.StatsHandler())
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example Response (all IPs):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"total_tracked_ips": 3,
|
||||||
|
"rate_limit_config": {
|
||||||
|
"requests_per_second": 10,
|
||||||
|
"burst": 20
|
||||||
|
},
|
||||||
|
"tracked_ips": [
|
||||||
|
{
|
||||||
|
"ip": "192.168.1.1",
|
||||||
|
"tokens_remaining": 15.5,
|
||||||
|
"limit": 10,
|
||||||
|
"burst": 20
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ip": "192.168.1.2",
|
||||||
|
"tokens_remaining": 18.2,
|
||||||
|
"limit": 10,
|
||||||
|
"burst": 20
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example Response (specific IP):**
|
||||||
|
```bash
|
||||||
|
GET /admin/rate-limit-stats?ip=192.168.1.1
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"ip": "192.168.1.1",
|
||||||
|
"tokens_remaining": 15.5,
|
||||||
|
"limit": 10,
|
||||||
|
"burst": 20
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complete Integration Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Create rate limiter
|
||||||
|
rateLimiter := middleware.NewRateLimiter(10, 20)
|
||||||
|
|
||||||
|
// Create blacklist
|
||||||
|
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
|
||||||
|
UseProxy: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
// API handler with both middlewares (blacklist first, then rate limit)
|
||||||
|
apiHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"message": "Success",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Apply middleware chain: blacklist -> rate limit -> handler
|
||||||
|
mux.Handle("/api/", blacklist.Middleware(rateLimiter.Middleware(apiHandler)))
|
||||||
|
|
||||||
|
// Admin endpoints
|
||||||
|
mux.Handle("/admin/rate-limit-stats", rateLimiter.StatsHandler())
|
||||||
|
mux.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
|
||||||
|
|
||||||
|
// Custom monitoring endpoint
|
||||||
|
mux.HandleFunc("/admin/monitor", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get rate limit stats
|
||||||
|
rateLimitInfo := rateLimiter.GetAllRateLimitInfo()
|
||||||
|
|
||||||
|
// Get blacklist stats
|
||||||
|
blockedIPs, blockedCIDRs := blacklist.GetBlacklist()
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"rate_limits": rateLimitInfo,
|
||||||
|
"blacklist": map[string]interface{}{
|
||||||
|
"ips": blockedIPs,
|
||||||
|
"cidrs": blockedCIDRs,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Dynamic blacklist management
|
||||||
|
mux.HandleFunc("/admin/block", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip := r.URL.Query().Get("ip")
|
||||||
|
reason := r.URL.Query().Get("reason")
|
||||||
|
|
||||||
|
if ip == "" {
|
||||||
|
http.Error(w, "IP required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := blacklist.BlockIP(ip, reason); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "Blocked %s: %s", ip, reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/admin/unblock", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip := r.URL.Query().Get("ip")
|
||||||
|
if ip == "" {
|
||||||
|
http.Error(w, "IP required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
blacklist.UnblockIP(ip)
|
||||||
|
fmt.Fprintf(w, "Unblocked %s", ip)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Auto-block IPs that exceed rate limit
|
||||||
|
mux.HandleFunc("/admin/auto-block-heavy-users", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
blocked := 0
|
||||||
|
|
||||||
|
for _, info := range rateLimiter.GetAllRateLimitInfo() {
|
||||||
|
// If tokens are very low, IP is making many requests
|
||||||
|
if info.TokensRemaining < 1.0 {
|
||||||
|
blacklist.BlockIP(info.IP, "Exceeded rate limit")
|
||||||
|
blocked++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "Blocked %d IPs exceeding rate limits", blocked)
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Println("Server starting on :8080")
|
||||||
|
fmt.Println("Rate limit stats: http://localhost:8080/admin/rate-limit-stats")
|
||||||
|
fmt.Println("Blacklist stats: http://localhost:8080/admin/blacklist-stats")
|
||||||
|
http.ListenAndServe(":8080", mux)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Monitoring Dashboard Example
|
||||||
|
|
||||||
|
Create a simple monitoring page:
|
||||||
|
|
||||||
|
```go
|
||||||
|
mux.HandleFunc("/admin/dashboard", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
html := `
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Security Dashboard</title>
|
||||||
|
<script>
|
||||||
|
async function loadStats() {
|
||||||
|
const rateLimitRes = await fetch('/admin/rate-limit-stats');
|
||||||
|
const rateLimitData = await rateLimitRes.json();
|
||||||
|
|
||||||
|
const blacklistRes = await fetch('/admin/blacklist-stats');
|
||||||
|
const blacklistData = await blacklistRes.json();
|
||||||
|
|
||||||
|
document.getElementById('rate-limit').innerHTML =
|
||||||
|
JSON.stringify(rateLimitData, null, 2);
|
||||||
|
document.getElementById('blacklist').innerHTML =
|
||||||
|
JSON.stringify(blacklistData, null, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
setInterval(loadStats, 5000); // Refresh every 5 seconds
|
||||||
|
loadStats();
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Security Dashboard</h1>
|
||||||
|
|
||||||
|
<h2>Rate Limits</h2>
|
||||||
|
<pre id="rate-limit">Loading...</pre>
|
||||||
|
|
||||||
|
<h2>Blacklist</h2>
|
||||||
|
<pre id="blacklist">Loading...</pre>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.Write([]byte(html))
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### 1. Proxy Configuration
|
||||||
|
Always set `UseProxy: true` when running behind a reverse proxy (nginx, Cloudflare, etc.):
|
||||||
|
```go
|
||||||
|
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
|
||||||
|
UseProxy: true, // Checks X-Forwarded-For headers
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Middleware Order
|
||||||
|
Apply blacklist before rate limiting to save resources:
|
||||||
|
```go
|
||||||
|
// Correct order: blacklist -> rate limit -> handler
|
||||||
|
handler := blacklist.Middleware(
|
||||||
|
rateLimiter.Middleware(yourHandler)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Secure Admin Endpoints
|
||||||
|
Protect admin endpoints with authentication:
|
||||||
|
```go
|
||||||
|
mux.Handle("/admin/", authMiddleware(adminHandler))
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Monitoring
|
||||||
|
Set up alerts when:
|
||||||
|
- Many IPs are being rate limited
|
||||||
|
- Blacklist grows too large
|
||||||
|
- Specific IPs are repeatedly blocked
|
||||||
|
|
||||||
|
### 5. Dynamic Response
|
||||||
|
Automatically block IPs that consistently exceed rate limits:
|
||||||
|
```go
|
||||||
|
// Check every minute
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
go func() {
|
||||||
|
for range ticker.C {
|
||||||
|
for _, info := range rateLimiter.GetAllRateLimitInfo() {
|
||||||
|
if info.TokensRemaining < 0.5 {
|
||||||
|
blacklist.BlockIP(info.IP, "Automated block: rate limit exceeded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. CIDR for Network Blocks
|
||||||
|
Use CIDR ranges to block entire networks efficiently:
|
||||||
|
```go
|
||||||
|
// Block entire subnets
|
||||||
|
blacklist.BlockCIDR("10.0.0.0/8", "Private network")
|
||||||
|
blacklist.BlockCIDR("192.168.0.0/16", "Local network")
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### IPBlacklist
|
||||||
|
|
||||||
|
#### Methods
|
||||||
|
- `BlockIP(ip, reason string) error` - Block a single IP address
|
||||||
|
- `BlockCIDR(cidr, reason string) error` - Block a CIDR range
|
||||||
|
- `UnblockIP(ip string)` - Remove IP from blacklist
|
||||||
|
- `UnblockCIDR(cidr string)` - Remove CIDR from blacklist
|
||||||
|
- `IsBlocked(ip string) (blocked bool, reason string)` - Check if IP is blocked
|
||||||
|
- `GetBlacklist() (ips, cidrs []string)` - Get all blocked IPs and CIDRs
|
||||||
|
- `Middleware(next http.Handler) http.Handler` - HTTP middleware
|
||||||
|
- `StatsHandler() http.Handler` - HTTP handler for statistics
|
||||||
|
|
||||||
|
### RateLimiter
|
||||||
|
|
||||||
|
#### Methods
|
||||||
|
- `GetTrackedIPs() []string` - Get all tracked IP addresses
|
||||||
|
- `GetRateLimitInfo(ip string) *RateLimitInfo` - Get info for specific IP
|
||||||
|
- `GetAllRateLimitInfo() []*RateLimitInfo` - Get info for all tracked IPs
|
||||||
|
- `Middleware(next http.Handler) http.Handler` - HTTP middleware
|
||||||
|
- `StatsHandler() http.Handler` - HTTP handler for statistics
|
||||||
|
|
||||||
|
#### RateLimitInfo Structure
|
||||||
|
```go
|
||||||
|
type RateLimitInfo struct {
|
||||||
|
IP string `json:"ip"`
|
||||||
|
TokensRemaining float64 `json:"tokens_remaining"`
|
||||||
|
Limit float64 `json:"limit"`
|
||||||
|
Burst int `json:"burst"`
|
||||||
|
}
|
||||||
|
```
|
||||||
204
pkg/middleware/blacklist.go
Normal file
204
pkg/middleware/blacklist.go
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPBlacklist provides IP blocking functionality
|
||||||
|
type IPBlacklist struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
ips map[string]bool // Individual IPs
|
||||||
|
cidrs []*net.IPNet // CIDR ranges
|
||||||
|
reason map[string]string
|
||||||
|
useProxy bool // Whether to check X-Forwarded-For headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlacklistConfig configures the IP blacklist
|
||||||
|
type BlacklistConfig struct {
|
||||||
|
// UseProxy indicates whether to extract IP from X-Forwarded-For/X-Real-IP headers
|
||||||
|
UseProxy bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIPBlacklist creates a new IP blacklist
|
||||||
|
func NewIPBlacklist(config BlacklistConfig) *IPBlacklist {
|
||||||
|
return &IPBlacklist{
|
||||||
|
ips: make(map[string]bool),
|
||||||
|
cidrs: make([]*net.IPNet, 0),
|
||||||
|
reason: make(map[string]string),
|
||||||
|
useProxy: config.UseProxy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockIP blocks a single IP address
|
||||||
|
func (bl *IPBlacklist) BlockIP(ip string, reason string) error {
|
||||||
|
// Validate IP
|
||||||
|
if net.ParseIP(ip) == nil {
|
||||||
|
return &net.ParseError{Type: "IP address", Text: ip}
|
||||||
|
}
|
||||||
|
|
||||||
|
bl.mu.Lock()
|
||||||
|
defer bl.mu.Unlock()
|
||||||
|
|
||||||
|
bl.ips[ip] = true
|
||||||
|
if reason != "" {
|
||||||
|
bl.reason[ip] = reason
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockCIDR blocks an IP range using CIDR notation
|
||||||
|
func (bl *IPBlacklist) BlockCIDR(cidr string, reason string) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
bl.mu.Lock()
|
||||||
|
defer bl.mu.Unlock()
|
||||||
|
|
||||||
|
bl.cidrs = append(bl.cidrs, ipNet)
|
||||||
|
if reason != "" {
|
||||||
|
bl.reason[cidr] = reason
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnblockIP removes an IP from the blacklist
|
||||||
|
func (bl *IPBlacklist) UnblockIP(ip string) {
|
||||||
|
bl.mu.Lock()
|
||||||
|
defer bl.mu.Unlock()
|
||||||
|
|
||||||
|
delete(bl.ips, ip)
|
||||||
|
delete(bl.reason, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnblockCIDR removes a CIDR range from the blacklist
|
||||||
|
func (bl *IPBlacklist) UnblockCIDR(cidr string) {
|
||||||
|
bl.mu.Lock()
|
||||||
|
defer bl.mu.Unlock()
|
||||||
|
|
||||||
|
// Find and remove the CIDR
|
||||||
|
for i, ipNet := range bl.cidrs {
|
||||||
|
if ipNet.String() == cidr {
|
||||||
|
bl.cidrs = append(bl.cidrs[:i], bl.cidrs[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(bl.reason, cidr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlocked checks if an IP is blacklisted
|
||||||
|
func (bl *IPBlacklist) IsBlocked(ip string) (bool, string) {
|
||||||
|
bl.mu.RLock()
|
||||||
|
defer bl.mu.RUnlock()
|
||||||
|
|
||||||
|
// Check individual IPs
|
||||||
|
if bl.ips[ip] {
|
||||||
|
return true, bl.reason[ip]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check CIDR ranges
|
||||||
|
parsedIP := net.ParseIP(ip)
|
||||||
|
if parsedIP == nil {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, ipNet := range bl.cidrs {
|
||||||
|
if ipNet.Contains(parsedIP) {
|
||||||
|
cidr := ipNet.String()
|
||||||
|
// Try to find reason by CIDR or by index
|
||||||
|
if reason, ok := bl.reason[cidr]; ok {
|
||||||
|
return true, reason
|
||||||
|
}
|
||||||
|
// Check if reason was stored by original CIDR string
|
||||||
|
for key, reason := range bl.reason {
|
||||||
|
if strings.Contains(key, "/") && key == cidr {
|
||||||
|
return true, reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Return true even if no reason found
|
||||||
|
if i < len(bl.cidrs) {
|
||||||
|
return true, ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlacklist returns all blacklisted IPs and CIDRs
|
||||||
|
func (bl *IPBlacklist) GetBlacklist() (ips []string, cidrs []string) {
|
||||||
|
bl.mu.RLock()
|
||||||
|
defer bl.mu.RUnlock()
|
||||||
|
|
||||||
|
ips = make([]string, 0, len(bl.ips))
|
||||||
|
for ip := range bl.ips {
|
||||||
|
ips = append(ips, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
cidrs = make([]string, 0, len(bl.cidrs))
|
||||||
|
for _, ipNet := range bl.cidrs {
|
||||||
|
cidrs = append(cidrs, ipNet.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return ips, cidrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns an HTTP middleware that blocks blacklisted IPs
|
||||||
|
func (bl *IPBlacklist) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var clientIP string
|
||||||
|
if bl.useProxy {
|
||||||
|
clientIP = getClientIP(r)
|
||||||
|
// Clean up IPv6 brackets if present
|
||||||
|
clientIP = strings.Trim(clientIP, "[]")
|
||||||
|
} else {
|
||||||
|
// Extract IP from RemoteAddr
|
||||||
|
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||||
|
clientIP = r.RemoteAddr[:idx]
|
||||||
|
} else {
|
||||||
|
clientIP = r.RemoteAddr
|
||||||
|
}
|
||||||
|
clientIP = strings.Trim(clientIP, "[]")
|
||||||
|
}
|
||||||
|
|
||||||
|
blocked, reason := bl.IsBlocked(clientIP)
|
||||||
|
if blocked {
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"error": "forbidden",
|
||||||
|
"message": "Access denied",
|
||||||
|
}
|
||||||
|
if reason != "" {
|
||||||
|
response["reason"] = reason
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatsHandler returns an HTTP handler that shows blacklist statistics
|
||||||
|
func (bl *IPBlacklist) StatsHandler() http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ips, cidrs := bl.GetBlacklist()
|
||||||
|
|
||||||
|
stats := map[string]interface{}{
|
||||||
|
"blocked_ips": ips,
|
||||||
|
"blocked_cidrs": cidrs,
|
||||||
|
"total_ips": len(ips),
|
||||||
|
"total_cidrs": len(cidrs),
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(stats)
|
||||||
|
})
|
||||||
|
}
|
||||||
254
pkg/middleware/blacklist_test.go
Normal file
254
pkg/middleware/blacklist_test.go
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIPBlacklist_BlockIP(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
|
||||||
|
// Block an IP
|
||||||
|
err := bl.BlockIP("192.168.1.100", "Suspicious activity")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BlockIP() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if IP is blocked
|
||||||
|
blocked, reason := bl.IsBlocked("192.168.1.100")
|
||||||
|
if !blocked {
|
||||||
|
t.Error("IP should be blocked")
|
||||||
|
}
|
||||||
|
if reason != "Suspicious activity" {
|
||||||
|
t.Errorf("Reason = %q, want %q", reason, "Suspicious activity")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check non-blocked IP
|
||||||
|
blocked, _ = bl.IsBlocked("192.168.1.1")
|
||||||
|
if blocked {
|
||||||
|
t.Error("IP should not be blocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_BlockCIDR(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
|
||||||
|
// Block a CIDR range
|
||||||
|
err := bl.BlockCIDR("10.0.0.0/24", "Internal network blocked")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BlockCIDR() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check IPs in range
|
||||||
|
testIPs := []string{
|
||||||
|
"10.0.0.1",
|
||||||
|
"10.0.0.100",
|
||||||
|
"10.0.0.254",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range testIPs {
|
||||||
|
blocked, _ := bl.IsBlocked(ip)
|
||||||
|
if !blocked {
|
||||||
|
t.Errorf("IP %s should be blocked by CIDR", ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check IP outside range
|
||||||
|
blocked, _ := bl.IsBlocked("10.0.1.1")
|
||||||
|
if blocked {
|
||||||
|
t.Error("IP outside CIDR range should not be blocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_UnblockIP(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
|
||||||
|
// Block and then unblock
|
||||||
|
bl.BlockIP("192.168.1.100", "Test")
|
||||||
|
|
||||||
|
blocked, _ := bl.IsBlocked("192.168.1.100")
|
||||||
|
if !blocked {
|
||||||
|
t.Error("IP should be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
bl.UnblockIP("192.168.1.100")
|
||||||
|
|
||||||
|
blocked, _ = bl.IsBlocked("192.168.1.100")
|
||||||
|
if blocked {
|
||||||
|
t.Error("IP should be unblocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_UnblockCIDR(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
|
||||||
|
// Block and then unblock CIDR
|
||||||
|
bl.BlockCIDR("10.0.0.0/24", "Test")
|
||||||
|
|
||||||
|
blocked, _ := bl.IsBlocked("10.0.0.1")
|
||||||
|
if !blocked {
|
||||||
|
t.Error("IP should be blocked by CIDR")
|
||||||
|
}
|
||||||
|
|
||||||
|
bl.UnblockCIDR("10.0.0.0/24")
|
||||||
|
|
||||||
|
blocked, _ = bl.IsBlocked("10.0.0.1")
|
||||||
|
if blocked {
|
||||||
|
t.Error("IP should be unblocked after CIDR removal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_Middleware(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
bl.BlockIP("192.168.1.100", "Banned")
|
||||||
|
|
||||||
|
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Blocked IP should get 403
|
||||||
|
t.Run("BlockedIP", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
var response map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Fatalf("Failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response["error"] != "forbidden" {
|
||||||
|
t.Errorf("Error = %v, want %q", response["error"], "forbidden")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Allowed IP should succeed
|
||||||
|
t.Run("AllowedIP", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_MiddlewareWithProxy(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: true})
|
||||||
|
bl.BlockIP("203.0.113.1", "Blocked via proxy")
|
||||||
|
|
||||||
|
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Test X-Forwarded-For
|
||||||
|
t.Run("X-Forwarded-For", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.1:12345"
|
||||||
|
req.Header.Set("X-Forwarded-For", "203.0.113.1")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test X-Real-IP
|
||||||
|
t.Run("X-Real-IP", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.1:12345"
|
||||||
|
req.Header.Set("X-Real-IP", "203.0.113.1")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_StatsHandler(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
bl.BlockIP("192.168.1.100", "Test1")
|
||||||
|
bl.BlockIP("192.168.1.101", "Test2")
|
||||||
|
bl.BlockCIDR("10.0.0.0/24", "Test CIDR")
|
||||||
|
|
||||||
|
handler := bl.StatsHandler()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/blacklist-stats", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var stats map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||||
|
t.Fatalf("Failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(stats["total_ips"].(float64)) != 2 {
|
||||||
|
t.Errorf("total_ips = %v, want 2", stats["total_ips"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(stats["total_cidrs"].(float64)) != 1 {
|
||||||
|
t.Errorf("total_cidrs = %v, want 1", stats["total_cidrs"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_GetBlacklist(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
bl.BlockIP("192.168.1.100", "")
|
||||||
|
bl.BlockIP("192.168.1.101", "")
|
||||||
|
bl.BlockCIDR("10.0.0.0/24", "")
|
||||||
|
|
||||||
|
ips, cidrs := bl.GetBlacklist()
|
||||||
|
|
||||||
|
if len(ips) != 2 {
|
||||||
|
t.Errorf("len(ips) = %d, want 2", len(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cidrs) != 1 {
|
||||||
|
t.Errorf("len(cidrs) = %d, want 1", len(cidrs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify CIDR format
|
||||||
|
if cidrs[0] != "10.0.0.0/24" {
|
||||||
|
t.Errorf("CIDR = %q, want %q", cidrs[0], "10.0.0.0/24")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_InvalidIP(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
|
||||||
|
err := bl.BlockIP("invalid-ip", "Test")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("BlockIP() should return error for invalid IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBlacklist_InvalidCIDR(t *testing.T) {
|
||||||
|
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||||
|
|
||||||
|
err := bl.BlockCIDR("invalid-cidr", "Test")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("BlockCIDR() should return error for invalid CIDR")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,7 +1,9 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -71,11 +73,11 @@ func (rl *RateLimiter) cleanupRoutine() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Middleware returns an HTTP middleware that applies rate limiting
|
// Middleware returns an HTTP middleware that applies rate limiting
|
||||||
|
// Automatically handles X-Forwarded-For headers when behind a proxy
|
||||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Use IP address as the rate limit key
|
// Extract client IP, handling proxy headers
|
||||||
// In production, you might want to use X-Forwarded-For or custom headers
|
key := getClientIP(r)
|
||||||
key := r.RemoteAddr
|
|
||||||
|
|
||||||
limiter := rl.getLimiter(key)
|
limiter := rl.getLimiter(key)
|
||||||
|
|
||||||
@ -108,3 +110,115 @@ func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string)
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RateLimitInfo contains information about a specific IP's rate limit status
|
||||||
|
type RateLimitInfo struct {
|
||||||
|
IP string `json:"ip"`
|
||||||
|
TokensRemaining float64 `json:"tokens_remaining"`
|
||||||
|
Limit float64 `json:"limit"`
|
||||||
|
Burst int `json:"burst"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTrackedIPs returns all IPs currently being tracked by the rate limiter
|
||||||
|
func (rl *RateLimiter) GetTrackedIPs() []string {
|
||||||
|
rl.mu.RLock()
|
||||||
|
defer rl.mu.RUnlock()
|
||||||
|
|
||||||
|
ips := make([]string, 0, len(rl.limiters))
|
||||||
|
for ip := range rl.limiters {
|
||||||
|
ips = append(ips, ip)
|
||||||
|
}
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRateLimitInfo returns rate limit information for a specific IP
|
||||||
|
func (rl *RateLimiter) GetRateLimitInfo(ip string) *RateLimitInfo {
|
||||||
|
rl.mu.RLock()
|
||||||
|
limiter, exists := rl.limiters[ip]
|
||||||
|
rl.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
// Return default info for untracked IP
|
||||||
|
return &RateLimitInfo{
|
||||||
|
IP: ip,
|
||||||
|
TokensRemaining: float64(rl.burst),
|
||||||
|
Limit: float64(rl.rate),
|
||||||
|
Burst: rl.burst,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RateLimitInfo{
|
||||||
|
IP: ip,
|
||||||
|
TokensRemaining: limiter.Tokens(),
|
||||||
|
Limit: float64(rl.rate),
|
||||||
|
Burst: rl.burst,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllRateLimitInfo returns rate limit information for all tracked IPs
|
||||||
|
func (rl *RateLimiter) GetAllRateLimitInfo() []*RateLimitInfo {
|
||||||
|
ips := rl.GetTrackedIPs()
|
||||||
|
info := make([]*RateLimitInfo, 0, len(ips))
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
info = append(info, rl.GetRateLimitInfo(ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatsHandler returns an HTTP handler that exposes rate limit statistics
|
||||||
|
// Example: GET /rate-limit-stats
|
||||||
|
func (rl *RateLimiter) StatsHandler() http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Support querying specific IP via ?ip=x.x.x.x
|
||||||
|
if ip := r.URL.Query().Get("ip"); ip != "" {
|
||||||
|
info := rl.GetRateLimitInfo(ip)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(info)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return all tracked IPs
|
||||||
|
allInfo := rl.GetAllRateLimitInfo()
|
||||||
|
|
||||||
|
stats := map[string]interface{}{
|
||||||
|
"total_tracked_ips": len(allInfo),
|
||||||
|
"rate_limit_config": map[string]interface{}{
|
||||||
|
"requests_per_second": float64(rl.rate),
|
||||||
|
"burst": rl.burst,
|
||||||
|
},
|
||||||
|
"tracked_ips": allInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(stats)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClientIP extracts the real client IP from the request
|
||||||
|
// Handles X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
// Check X-Forwarded-For header (most common in production)
|
||||||
|
// Format: X-Forwarded-For: client, proxy1, proxy2
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
// Take the first IP (the original client)
|
||||||
|
if idx := strings.Index(xff, ","); idx != -1 {
|
||||||
|
return strings.TrimSpace(xff[:idx])
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(xff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check X-Real-IP header (used by some proxies like nginx)
|
||||||
|
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||||
|
return strings.TrimSpace(xri)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to RemoteAddr
|
||||||
|
// Remove port if present (format: "ip:port")
|
||||||
|
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||||
|
return r.RemoteAddr[:idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
|||||||
388
pkg/middleware/ratelimit_test.go
Normal file
388
pkg/middleware/ratelimit_test.go
Normal file
@ -0,0 +1,388 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimiter(t *testing.T) {
|
||||||
|
// Create rate limiter: 2 requests per second, burst of 2
|
||||||
|
rl := NewRateLimiter(2, 2)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
// First request should succeed
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("First request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second request should succeed (within burst)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Second request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third request should be rate limited
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("Third request should be rate limited: got %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for rate limiter to refill
|
||||||
|
time.Sleep(600 * time.Millisecond)
|
||||||
|
|
||||||
|
// Request should succeed again
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Request after wait failed: got %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterDifferentIPs(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(1, 1)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// First IP
|
||||||
|
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req1.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
|
||||||
|
// Second IP
|
||||||
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req2.RemoteAddr = "192.168.1.2:12345"
|
||||||
|
|
||||||
|
// Both should succeed (different IPs)
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w1, req1)
|
||||||
|
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
if w1.Code != http.StatusOK {
|
||||||
|
t.Errorf("First IP request failed: got %d, want %d", w1.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w2.Code != http.StatusOK {
|
||||||
|
t.Errorf("Second IP request failed: got %d, want %d", w2.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClientIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
remoteAddr string
|
||||||
|
xForwardedFor string
|
||||||
|
xRealIP string
|
||||||
|
expectedIP string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RemoteAddr only",
|
||||||
|
remoteAddr: "192.168.1.1:12345",
|
||||||
|
expectedIP: "192.168.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For single IP",
|
||||||
|
remoteAddr: "10.0.0.1:12345",
|
||||||
|
xForwardedFor: "203.0.113.1",
|
||||||
|
expectedIP: "203.0.113.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For multiple IPs",
|
||||||
|
remoteAddr: "10.0.0.1:12345",
|
||||||
|
xForwardedFor: "203.0.113.1, 10.0.0.2, 10.0.0.3",
|
||||||
|
expectedIP: "203.0.113.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Real-IP",
|
||||||
|
remoteAddr: "10.0.0.1:12345",
|
||||||
|
xRealIP: "203.0.113.1",
|
||||||
|
expectedIP: "203.0.113.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
||||||
|
remoteAddr: "10.0.0.1:12345",
|
||||||
|
xForwardedFor: "203.0.113.1",
|
||||||
|
xRealIP: "203.0.113.2",
|
||||||
|
expectedIP: "203.0.113.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address",
|
||||||
|
remoteAddr: "[2001:db8::1]:12345",
|
||||||
|
expectedIP: "[2001:db8::1]",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
|
||||||
|
if tt.xForwardedFor != "" {
|
||||||
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||||
|
}
|
||||||
|
if tt.xRealIP != "" {
|
||||||
|
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := getClientIP(req)
|
||||||
|
if ip != tt.expectedIP {
|
||||||
|
t.Errorf("getClientIP() = %q, want %q", ip, tt.expectedIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterWithCustomKeyFunc(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(1, 1)
|
||||||
|
|
||||||
|
// Use user ID as key
|
||||||
|
keyFunc := func(r *http.Request) string {
|
||||||
|
userID := r.Header.Get("X-User-ID")
|
||||||
|
if userID == "" {
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return "user:" + userID
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := rl.MiddlewareWithKeyFunc(keyFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// User 1
|
||||||
|
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req1.Header.Set("X-User-ID", "user1")
|
||||||
|
|
||||||
|
// User 2
|
||||||
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req2.Header.Set("X-User-ID", "user2")
|
||||||
|
|
||||||
|
// Both users should succeed (different keys)
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w1, req1)
|
||||||
|
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
if w1.Code != http.StatusOK {
|
||||||
|
t.Errorf("User 1 request failed: got %d, want %d", w1.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w2.Code != http.StatusOK {
|
||||||
|
t.Errorf("User 2 request failed: got %d, want %d", w2.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// User 1 second request should be rate limited
|
||||||
|
w1 = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w1, req1)
|
||||||
|
|
||||||
|
if w1.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("User 1 second request should be rate limited: got %d, want %d", w1.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_GetTrackedIPs(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(10, 10)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Make requests from different IPs
|
||||||
|
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
||||||
|
for _, ip := range ips {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = ip + ":12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check tracked IPs
|
||||||
|
trackedIPs := rl.GetTrackedIPs()
|
||||||
|
if len(trackedIPs) != len(ips) {
|
||||||
|
t.Errorf("len(trackedIPs) = %d, want %d", len(trackedIPs), len(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all IPs are tracked
|
||||||
|
ipMap := make(map[string]bool)
|
||||||
|
for _, ip := range trackedIPs {
|
||||||
|
ipMap[ip] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
if !ipMap[ip] {
|
||||||
|
t.Errorf("IP %s should be tracked", ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_GetRateLimitInfo(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(10, 5)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Make a request
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Get rate limit info
|
||||||
|
info := rl.GetRateLimitInfo("192.168.1.1")
|
||||||
|
|
||||||
|
if info.IP != "192.168.1.1" {
|
||||||
|
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Limit != 10.0 {
|
||||||
|
t.Errorf("Limit = %f, want 10.0", info.Limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Burst != 5 {
|
||||||
|
t.Errorf("Burst = %d, want 5", info.Burst)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tokens should be less than burst after one request
|
||||||
|
if info.TokensRemaining >= float64(info.Burst) {
|
||||||
|
t.Errorf("TokensRemaining = %f, should be less than %d", info.TokensRemaining, info.Burst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_GetRateLimitInfo_UntrackedIP(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(10, 5)
|
||||||
|
|
||||||
|
// Get info for untracked IP (should return default)
|
||||||
|
info := rl.GetRateLimitInfo("192.168.1.1")
|
||||||
|
|
||||||
|
if info.IP != "192.168.1.1" {
|
||||||
|
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.TokensRemaining != float64(rl.burst) {
|
||||||
|
t.Errorf("TokensRemaining = %f, want %d (full burst)", info.TokensRemaining, rl.burst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_GetAllRateLimitInfo(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(10, 10)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Make requests from different IPs
|
||||||
|
ips := []string{"192.168.1.1", "192.168.1.2"}
|
||||||
|
for _, ip := range ips {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = ip + ":12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all rate limit info
|
||||||
|
allInfo := rl.GetAllRateLimitInfo()
|
||||||
|
|
||||||
|
if len(allInfo) != len(ips) {
|
||||||
|
t.Errorf("len(allInfo) = %d, want %d", len(allInfo), len(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify each IP has info
|
||||||
|
for _, info := range allInfo {
|
||||||
|
found := false
|
||||||
|
for _, ip := range ips {
|
||||||
|
if info.IP == ip {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Unexpected IP in info: %s", info.IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_StatsHandler(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(10, 5)
|
||||||
|
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Make requests from different IPs
|
||||||
|
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req1.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w1, req1)
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req2.RemoteAddr = "192.168.1.2:12345"
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
// Test stats handler (all IPs)
|
||||||
|
t.Run("AllIPs", func(t *testing.T) {
|
||||||
|
statsHandler := rl.StatsHandler()
|
||||||
|
req := httptest.NewRequest("GET", "/rate-limit-stats", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
statsHandler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var stats map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||||
|
t.Fatalf("Failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(stats["total_tracked_ips"].(float64)) != 2 {
|
||||||
|
t.Errorf("total_tracked_ips = %v, want 2", stats["total_tracked_ips"])
|
||||||
|
}
|
||||||
|
|
||||||
|
config := stats["rate_limit_config"].(map[string]interface{})
|
||||||
|
if config["requests_per_second"].(float64) != 10.0 {
|
||||||
|
t.Errorf("requests_per_second = %v, want 10.0", config["requests_per_second"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test stats handler (specific IP)
|
||||||
|
t.Run("SpecificIP", func(t *testing.T) {
|
||||||
|
statsHandler := rl.StatsHandler()
|
||||||
|
req := httptest.NewRequest("GET", "/rate-limit-stats?ip=192.168.1.1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
statsHandler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var info RateLimitInfo
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &info); err != nil {
|
||||||
|
t.Fatalf("Failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IP != "192.168.1.1" {
|
||||||
|
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
273
pkg/middleware/sanitize_test.go
Normal file
273
pkg/middleware/sanitize_test.go
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeXSS(t *testing.T) {
|
||||||
|
sanitizer := DefaultSanitizer()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
contains string // String that should NOT be in output
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Script tag",
|
||||||
|
input: "<script>alert(1)</script>",
|
||||||
|
contains: "<script>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JavaScript protocol",
|
||||||
|
input: "javascript:alert(1)",
|
||||||
|
contains: "javascript:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Event handler",
|
||||||
|
input: "<img onerror='alert(1)'>",
|
||||||
|
contains: "onerror=",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Iframe",
|
||||||
|
input: "<iframe src='evil.com'></iframe>",
|
||||||
|
contains: "<iframe",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := sanitizer.Sanitize(tt.input)
|
||||||
|
if result == tt.input {
|
||||||
|
t.Errorf("Sanitize() did not modify input: %q", tt.input)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeNullBytes(t *testing.T) {
|
||||||
|
sanitizer := DefaultSanitizer()
|
||||||
|
|
||||||
|
input := "hello\x00world"
|
||||||
|
result := sanitizer.Sanitize(input)
|
||||||
|
|
||||||
|
if result == input {
|
||||||
|
t.Error("Null bytes should be removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) >= len(input) {
|
||||||
|
t.Errorf("Result length should be less than input: got %d, input %d", len(result), len(input))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeControlCharacters(t *testing.T) {
|
||||||
|
sanitizer := DefaultSanitizer()
|
||||||
|
|
||||||
|
// Include various control characters
|
||||||
|
input := "hello\x01\x02world\x1F"
|
||||||
|
result := sanitizer.Sanitize(input)
|
||||||
|
|
||||||
|
if result == input {
|
||||||
|
t.Error("Control characters should be removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Newlines, tabs, carriage returns should be preserved
|
||||||
|
input2 := "hello\nworld\t\r"
|
||||||
|
result2 := sanitizer.Sanitize(input2)
|
||||||
|
|
||||||
|
if result2 != input2 {
|
||||||
|
t.Errorf("Safe control characters should be preserved: got %q, want %q", result2, input2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeMap(t *testing.T) {
|
||||||
|
sanitizer := DefaultSanitizer()
|
||||||
|
|
||||||
|
input := map[string]interface{}{
|
||||||
|
"name": "<script>alert(1)</script>John",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"nested": map[string]interface{}{
|
||||||
|
"bio": "<iframe src='evil.com'>Bio</iframe>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := sanitizer.SanitizeMap(input)
|
||||||
|
|
||||||
|
// Check that script tag was removed/escaped
|
||||||
|
name, ok := result["name"].(string)
|
||||||
|
if !ok || name == input["name"] {
|
||||||
|
t.Error("Name should be sanitized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check nested map
|
||||||
|
nested, ok := result["nested"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Nested should still be a map")
|
||||||
|
}
|
||||||
|
|
||||||
|
bio, ok := nested["bio"].(string)
|
||||||
|
if !ok || bio == input["nested"].(map[string]interface{})["bio"] {
|
||||||
|
t.Error("Nested bio should be sanitized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeMiddleware(t *testing.T) {
|
||||||
|
sanitizer := DefaultSanitizer()
|
||||||
|
|
||||||
|
handler := sanitizer.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Check that query param was sanitized
|
||||||
|
param := r.URL.Query().Get("q")
|
||||||
|
if param == "<script>alert(1)</script>" {
|
||||||
|
t.Error("Query param should be sanitized")
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test?q=<script>alert(1)</script>", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Handler failed: got %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeFilename(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
contains string // String that should NOT be in output
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Path traversal",
|
||||||
|
input: "../../../etc/passwd",
|
||||||
|
contains: "..",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Absolute path",
|
||||||
|
input: "/etc/passwd",
|
||||||
|
contains: "/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Windows path",
|
||||||
|
input: "..\\..\\windows\\system32",
|
||||||
|
contains: "\\",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Null byte",
|
||||||
|
input: "file\x00.txt",
|
||||||
|
contains: "\x00",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeFilename(tt.input)
|
||||||
|
if result == tt.input {
|
||||||
|
t.Errorf("SanitizeFilename() did not modify input: %q", tt.input)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeEmail(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Uppercase",
|
||||||
|
input: "TEST@EXAMPLE.COM",
|
||||||
|
expected: "test@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Whitespace",
|
||||||
|
input: " test@example.com ",
|
||||||
|
expected: "test@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Null bytes",
|
||||||
|
input: "test\x00@example.com",
|
||||||
|
expected: "test@example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeEmail(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeEmail() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "JavaScript protocol",
|
||||||
|
input: "javascript:alert(1)",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Data protocol",
|
||||||
|
input: "data:text/html,<script>alert(1)</script>",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid HTTP URL",
|
||||||
|
input: "https://example.com",
|
||||||
|
expected: "https://example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeURL(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeURL() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStrictSanitizer(t *testing.T) {
|
||||||
|
sanitizer := StrictSanitizer()
|
||||||
|
|
||||||
|
input := "<b>Bold text</b> with <script>alert(1)</script>"
|
||||||
|
result := sanitizer.Sanitize(input)
|
||||||
|
|
||||||
|
// Should strip ALL HTML tags
|
||||||
|
if result == input {
|
||||||
|
t.Error("Strict sanitizer should modify input")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not contain any HTML tags
|
||||||
|
if len(result) > 0 && (result[0] == '<' || result[len(result)-1] == '>') {
|
||||||
|
t.Error("Result should not contain HTML tags")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxStringLength(t *testing.T) {
|
||||||
|
sanitizer := &Sanitizer{
|
||||||
|
MaxStringLength: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
input := "This is a very long string that exceeds the maximum length"
|
||||||
|
result := sanitizer.Sanitize(input)
|
||||||
|
|
||||||
|
if len(result) != 10 {
|
||||||
|
t.Errorf("Result length = %d, want 10", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != input[:10] {
|
||||||
|
t.Errorf("Result = %q, want %q", result, input[:10])
|
||||||
|
}
|
||||||
|
}
|
||||||
126
pkg/middleware/sizelimit_test.go
Normal file
126
pkg/middleware/sizelimit_test.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestSizeLimiter(t *testing.T) {
|
||||||
|
// 1KB limit
|
||||||
|
limiter := NewRequestSizeLimiter(1024)
|
||||||
|
|
||||||
|
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Try to read body
|
||||||
|
_, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Small request (should succeed)
|
||||||
|
t.Run("SmallRequest", func(t *testing.T) {
|
||||||
|
body := bytes.NewReader(make([]byte, 512)) // 512 bytes
|
||||||
|
req := httptest.NewRequest("POST", "/test", body)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Small request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check header
|
||||||
|
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "1024" {
|
||||||
|
t.Errorf("MaxRequestSizeHeader = %q, want %q", maxSize, "1024")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Large request (should fail)
|
||||||
|
t.Run("LargeRequest", func(t *testing.T) {
|
||||||
|
body := bytes.NewReader(make([]byte, 2048)) // 2KB
|
||||||
|
req := httptest.NewRequest("POST", "/test", body)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusRequestEntityTooLarge {
|
||||||
|
t.Errorf("Large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSizeLimiterDefault(t *testing.T) {
|
||||||
|
// Default limiter (10MB)
|
||||||
|
limiter := NewRequestSizeLimiter(0)
|
||||||
|
|
||||||
|
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", bytes.NewReader(make([]byte, 1024)))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check default size
|
||||||
|
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "10485760" {
|
||||||
|
t.Errorf("Default MaxRequestSizeHeader = %q, want %q", maxSize, "10485760")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSizeLimiterWithCustomSize(t *testing.T) {
|
||||||
|
limiter := NewRequestSizeLimiter(1024)
|
||||||
|
|
||||||
|
// Premium users get 10MB, regular users get 1KB
|
||||||
|
sizeFunc := func(r *http.Request) int64 {
|
||||||
|
if r.Header.Get("X-User-Tier") == "premium" {
|
||||||
|
return Size10MB
|
||||||
|
}
|
||||||
|
return 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := limiter.MiddlewareWithCustomSize(sizeFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Regular user with large request (should fail)
|
||||||
|
t.Run("RegularUserLargeRequest", func(t *testing.T) {
|
||||||
|
body := bytes.NewReader(make([]byte, 2048))
|
||||||
|
req := httptest.NewRequest("POST", "/test", body)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusRequestEntityTooLarge {
|
||||||
|
t.Errorf("Regular user large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Premium user with large request (should succeed)
|
||||||
|
t.Run("PremiumUserLargeRequest", func(t *testing.T) {
|
||||||
|
body := bytes.NewReader(make([]byte, 2048))
|
||||||
|
req := httptest.NewRequest("POST", "/test", body)
|
||||||
|
req.Header.Set("X-User-Tier", "premium")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Premium user large request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
231
pkg/server/shutdown_test.go
Normal file
231
pkg/server/shutdown_test.go
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGracefulServerTrackRequests(t *testing.T) {
|
||||||
|
srv := NewGracefulServer(Config{
|
||||||
|
Addr: ":0",
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
||||||
|
|
||||||
|
// Start some requests
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a bit for requests to start
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check in-flight count
|
||||||
|
inFlight := srv.InFlightRequests()
|
||||||
|
if inFlight == 0 {
|
||||||
|
t.Error("Should have in-flight requests")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all requests to complete
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Check that counter is back to zero
|
||||||
|
inFlight = srv.InFlightRequests()
|
||||||
|
if inFlight != 0 {
|
||||||
|
t.Errorf("In-flight requests should be 0, got %d", inFlight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) {
|
||||||
|
srv := NewGracefulServer(Config{
|
||||||
|
Addr: ":0",
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
||||||
|
|
||||||
|
// Mark as shutting down
|
||||||
|
srv.isShuttingDown.Store(true)
|
||||||
|
|
||||||
|
// Try to make a request
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Should get 503
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("Expected 503, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthCheckHandler(t *testing.T) {
|
||||||
|
srv := NewGracefulServer(Config{
|
||||||
|
Addr: ":0",
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := srv.HealthCheckHandler()
|
||||||
|
|
||||||
|
// Healthy
|
||||||
|
t.Run("Healthy", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Body.String() != `{"status":"healthy"}` {
|
||||||
|
t.Errorf("Unexpected body: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Shutting down
|
||||||
|
t.Run("ShuttingDown", func(t *testing.T) {
|
||||||
|
srv.isShuttingDown.Store(true)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("Expected 503, got %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadinessHandler(t *testing.T) {
|
||||||
|
srv := NewGracefulServer(Config{
|
||||||
|
Addr: ":0",
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := srv.ReadinessHandler()
|
||||||
|
|
||||||
|
// Ready with no in-flight requests
|
||||||
|
t.Run("Ready", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/ready", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
if body != `{"ready":true,"in_flight_requests":0}` {
|
||||||
|
t.Errorf("Unexpected body: %s", body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Not ready during shutdown
|
||||||
|
t.Run("NotReady", func(t *testing.T) {
|
||||||
|
srv.isShuttingDown.Store(true)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/ready", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("Expected 503, got %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShutdownCallbacks(t *testing.T) {
|
||||||
|
callbackExecuted := false
|
||||||
|
|
||||||
|
RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
callbackExecuted = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := executeShutdownCallbacks(ctx)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("executeShutdownCallbacks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !callbackExecuted {
|
||||||
|
t.Error("Shutdown callback was not executed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset for other tests
|
||||||
|
shutdownCallbacks = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDrainRequests(t *testing.T) {
|
||||||
|
srv := NewGracefulServer(Config{
|
||||||
|
Addr: ":0",
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||||
|
DrainTimeout: 1 * time.Second,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate in-flight requests
|
||||||
|
srv.inFlightRequests.Add(3)
|
||||||
|
|
||||||
|
// Start draining in background
|
||||||
|
go func() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
// Simulate requests completing
|
||||||
|
srv.inFlightRequests.Add(-3)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := srv.drainRequests(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("drainRequests() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.InFlightRequests() != 0 {
|
||||||
|
t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDrainRequestsTimeout(t *testing.T) {
|
||||||
|
srv := NewGracefulServer(Config{
|
||||||
|
Addr: ":0",
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||||
|
DrainTimeout: 100 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate in-flight requests that don't complete
|
||||||
|
srv.inFlightRequests.Add(5)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := srv.drainRequests(ctx)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("drainRequests() should timeout with error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
srv.inFlightRequests.Add(-5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClientIP(t *testing.T) {
|
||||||
|
// This test is in ratelimit_test.go since getClientIP is used by rate limiter
|
||||||
|
// Including here for completeness of server tests
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user