mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
111 lines
2.7 KiB
Go
111 lines
2.7 KiB
Go
package middleware
|
|
|
|
import (
	"net"
	"net/http"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
|
|
|
|
// RateLimiter provides rate limiting functionality
|
|
type RateLimiter struct {
|
|
mu sync.RWMutex
|
|
limiters map[string]*rate.Limiter
|
|
rate rate.Limit
|
|
burst int
|
|
cleanup time.Duration
|
|
}
|
|
|
|
// NewRateLimiter creates a new rate limiter
|
|
// rps is requests per second, burst is the maximum burst size
|
|
func NewRateLimiter(rps float64, burst int) *RateLimiter {
|
|
rl := &RateLimiter{
|
|
limiters: make(map[string]*rate.Limiter),
|
|
rate: rate.Limit(rps),
|
|
burst: burst,
|
|
cleanup: 5 * time.Minute, // Clean up stale limiters every 5 minutes
|
|
}
|
|
|
|
// Start cleanup goroutine
|
|
go rl.cleanupRoutine()
|
|
|
|
return rl
|
|
}
|
|
|
|
// getLimiter returns the rate limiter for a given key (e.g., IP address)
|
|
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
|
|
rl.mu.RLock()
|
|
limiter, exists := rl.limiters[key]
|
|
rl.mu.RUnlock()
|
|
|
|
if exists {
|
|
return limiter
|
|
}
|
|
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
// Double-check after acquiring write lock
|
|
if limiter, exists := rl.limiters[key]; exists {
|
|
return limiter
|
|
}
|
|
|
|
limiter = rate.NewLimiter(rl.rate, rl.burst)
|
|
rl.limiters[key] = limiter
|
|
return limiter
|
|
}
|
|
|
|
// cleanupRoutine periodically removes inactive limiters
|
|
func (rl *RateLimiter) cleanupRoutine() {
|
|
ticker := time.NewTicker(rl.cleanup)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
rl.mu.Lock()
|
|
// Simple cleanup: remove all limiters
|
|
// In production, you might want to track last access time
|
|
rl.limiters = make(map[string]*rate.Limiter)
|
|
rl.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// Middleware returns an HTTP middleware that applies rate limiting
|
|
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Use IP address as the rate limit key
|
|
// In production, you might want to use X-Forwarded-For or custom headers
|
|
key := r.RemoteAddr
|
|
|
|
limiter := rl.getLimiter(key)
|
|
|
|
if !limiter.Allow() {
|
|
http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// MiddlewareWithKeyFunc returns an HTTP middleware with a custom key extraction function
|
|
func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
key := keyFunc(r)
|
|
if key == "" {
|
|
key = r.RemoteAddr
|
|
}
|
|
|
|
limiter := rl.getLimiter(key)
|
|
|
|
if !limiter.Allow() {
|
|
http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|