mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-05 11:24:26 +00:00
Added blacklist middleware
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -71,11 +73,11 @@ func (rl *RateLimiter) cleanupRoutine() {
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that applies rate limiting
|
||||
// Automatically handles X-Forwarded-For headers when behind a proxy
|
||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Use IP address as the rate limit key
|
||||
// In production, you might want to use X-Forwarded-For or custom headers
|
||||
key := r.RemoteAddr
|
||||
// Extract client IP, handling proxy headers
|
||||
key := getClientIP(r)
|
||||
|
||||
limiter := rl.getLimiter(key)
|
||||
|
||||
@@ -108,3 +110,115 @@ func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitInfo contains information about a specific IP's rate limit status
|
||||
type RateLimitInfo struct {
|
||||
IP string `json:"ip"`
|
||||
TokensRemaining float64 `json:"tokens_remaining"`
|
||||
Limit float64 `json:"limit"`
|
||||
Burst int `json:"burst"`
|
||||
}
|
||||
|
||||
// GetTrackedIPs returns all IPs currently being tracked by the rate limiter
|
||||
func (rl *RateLimiter) GetTrackedIPs() []string {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
ips := make([]string, 0, len(rl.limiters))
|
||||
for ip := range rl.limiters {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// GetRateLimitInfo returns rate limit information for a specific IP
|
||||
func (rl *RateLimiter) GetRateLimitInfo(ip string) *RateLimitInfo {
|
||||
rl.mu.RLock()
|
||||
limiter, exists := rl.limiters[ip]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Return default info for untracked IP
|
||||
return &RateLimitInfo{
|
||||
IP: ip,
|
||||
TokensRemaining: float64(rl.burst),
|
||||
Limit: float64(rl.rate),
|
||||
Burst: rl.burst,
|
||||
}
|
||||
}
|
||||
|
||||
return &RateLimitInfo{
|
||||
IP: ip,
|
||||
TokensRemaining: limiter.Tokens(),
|
||||
Limit: float64(rl.rate),
|
||||
Burst: rl.burst,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllRateLimitInfo returns rate limit information for all tracked IPs
|
||||
func (rl *RateLimiter) GetAllRateLimitInfo() []*RateLimitInfo {
|
||||
ips := rl.GetTrackedIPs()
|
||||
info := make([]*RateLimitInfo, 0, len(ips))
|
||||
|
||||
for _, ip := range ips {
|
||||
info = append(info, rl.GetRateLimitInfo(ip))
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// StatsHandler returns an HTTP handler that exposes rate limit statistics
|
||||
// Example: GET /rate-limit-stats
|
||||
func (rl *RateLimiter) StatsHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Support querying specific IP via ?ip=x.x.x.x
|
||||
if ip := r.URL.Query().Get("ip"); ip != "" {
|
||||
info := rl.GetRateLimitInfo(ip)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(info)
|
||||
return
|
||||
}
|
||||
|
||||
// Return all tracked IPs
|
||||
allInfo := rl.GetAllRateLimitInfo()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_tracked_ips": len(allInfo),
|
||||
"rate_limit_config": map[string]interface{}{
|
||||
"requests_per_second": float64(rl.rate),
|
||||
"burst": rl.burst,
|
||||
},
|
||||
"tracked_ips": allInfo,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(stats)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP extracts the real client IP from the request
|
||||
// Handles X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (most common in production)
|
||||
// Format: X-Forwarded-For: client, proxy1, proxy2
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP (the original client)
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Check X-Real-IP header (used by some proxies like nginx)
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
// Remove port if present (format: "ip:port")
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user