diff --git a/pkg/middleware/blacklist.go b/pkg/middleware/blacklist.go index 274f4b7..9c9e243 100644 --- a/pkg/middleware/blacklist.go +++ b/pkg/middleware/blacklist.go @@ -6,15 +6,17 @@ import ( "net/http" "strings" "sync" + + "github.com/bitechdev/ResolveSpec/pkg/logger" ) // IPBlacklist provides IP blocking functionality type IPBlacklist struct { - mu sync.RWMutex - ips map[string]bool // Individual IPs - cidrs []*net.IPNet // CIDR ranges - reason map[string]string - useProxy bool // Whether to check X-Forwarded-For headers + mu sync.RWMutex + ips map[string]bool // Individual IPs + cidrs []*net.IPNet // CIDR ranges + reason map[string]string + useProxy bool // Whether to check X-Forwarded-For headers } // BlacklistConfig configures the IP blacklist @@ -92,7 +94,7 @@ func (bl *IPBlacklist) UnblockCIDR(cidr string) { } // IsBlocked checks if an IP is blacklisted -func (bl *IPBlacklist) IsBlocked(ip string) (bool, string) { +func (bl *IPBlacklist) IsBlocked(ip string) (blacklist bool, reason string) { bl.mu.RLock() defer bl.mu.RUnlock() @@ -178,7 +180,10 @@ func (bl *IPBlacklist) Middleware(next http.Handler) http.Handler { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) - json.NewEncoder(w).Encode(response) + err := json.NewEncoder(w).Encode(response) + if err != nil { + logger.Debug("Failed to write blacklist response: %v", err) + } return } @@ -192,13 +197,16 @@ func (bl *IPBlacklist) StatsHandler() http.Handler { ips, cidrs := bl.GetBlacklist() stats := map[string]interface{}{ - "blocked_ips": ips, - "blocked_cidrs": cidrs, - "total_ips": len(ips), - "total_cidrs": len(cidrs), + "blocked_ips": ips, + "blocked_cidrs": cidrs, + "total_ips": len(ips), + "total_cidrs": len(cidrs), } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(stats) + err := json.NewEncoder(w).Encode(stats) + if err != nil { + logger.Debug("Failed to encode stats: %v", err) + } }) } diff --git a/pkg/middleware/ratelimit.go b/pkg/middleware/ratelimit.go index ac10314..6e2d0ac 100644 --- a/pkg/middleware/ratelimit.go +++ b/pkg/middleware/ratelimit.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/bitechdev/ResolveSpec/pkg/logger" "golang.org/x/time/rate" ) @@ -113,10 +114,10 @@ func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string) // RateLimitInfo contains information about a specific IP's rate limit status type RateLimitInfo struct { - IP string `json:"ip"` - TokensRemaining float64 `json:"tokens_remaining"` - Limit float64 `json:"limit"` - Burst int `json:"burst"` + IP string `json:"ip"` + TokensRemaining float64 `json:"tokens_remaining"` + Limit float64 `json:"limit"` + Burst int `json:"burst"` } // GetTrackedIPs returns all IPs currently being tracked by the rate limiter @@ -140,18 +141,18 @@ func (rl *RateLimiter) GetRateLimitInfo(ip string) *RateLimitInfo { if !exists { // Return default info for untracked IP return &RateLimitInfo{ - IP: ip, + IP: ip, TokensRemaining: float64(rl.burst), - Limit: float64(rl.rate), - Burst: rl.burst, + Limit: float64(rl.rate), + Burst: rl.burst, } } return &RateLimitInfo{ - IP: ip, + IP: ip, TokensRemaining: limiter.Tokens(), - Limit: float64(rl.rate), - Burst: rl.burst, + Limit: float64(rl.rate), + Burst: rl.burst, } } @@ -175,7 +176,10 @@ func (rl *RateLimiter) StatsHandler() http.Handler { if ip := r.URL.Query().Get("ip"); ip != "" { info := rl.GetRateLimitInfo(ip) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(info) + err := json.NewEncoder(w).Encode(info) + if err != nil { + logger.Debug("Failed to encode json: %v", err) + } return } @@ -186,13 +190,16 @@ func (rl *RateLimiter) StatsHandler() http.Handler { "total_tracked_ips": len(allInfo), "rate_limit_config": map[string]interface{}{ "requests_per_second": float64(rl.rate), - "burst": rl.burst, + "burst": rl.burst, }, "tracked_ips": allInfo, } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(stats) + err := json.NewEncoder(w).Encode(stats) + if err != nil { + logger.Debug("Failed to encode json: %v", err) + } }) }