package middleware import ( "net/http" "sync" "time" "golang.org/x/time/rate" ) // RateLimiter provides rate limiting functionality type RateLimiter struct { mu sync.RWMutex limiters map[string]*rate.Limiter rate rate.Limit burst int cleanup time.Duration } // NewRateLimiter creates a new rate limiter // rps is requests per second, burst is the maximum burst size func NewRateLimiter(rps float64, burst int) *RateLimiter { rl := &RateLimiter{ limiters: make(map[string]*rate.Limiter), rate: rate.Limit(rps), burst: burst, cleanup: 5 * time.Minute, // Clean up stale limiters every 5 minutes } // Start cleanup goroutine go rl.cleanupRoutine() return rl } // getLimiter returns the rate limiter for a given key (e.g., IP address) func (rl *RateLimiter) getLimiter(key string) *rate.Limiter { rl.mu.RLock() limiter, exists := rl.limiters[key] rl.mu.RUnlock() if exists { return limiter } rl.mu.Lock() defer rl.mu.Unlock() // Double-check after acquiring write lock if limiter, exists := rl.limiters[key]; exists { return limiter } limiter = rate.NewLimiter(rl.rate, rl.burst) rl.limiters[key] = limiter return limiter } // cleanupRoutine periodically removes inactive limiters func (rl *RateLimiter) cleanupRoutine() { ticker := time.NewTicker(rl.cleanup) defer ticker.Stop() for range ticker.C { rl.mu.Lock() // Simple cleanup: remove all limiters // In production, you might want to track last access time rl.limiters = make(map[string]*rate.Limiter) rl.mu.Unlock() } } // Middleware returns an HTTP middleware that applies rate limiting func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Use IP address as the rate limit key // In production, you might want to use X-Forwarded-For or custom headers key := r.RemoteAddr limiter := rl.getLimiter(key) if !limiter.Allow() { http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // MiddlewareWithKeyFunc returns an HTTP middleware with a custom key extraction function func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key := keyFunc(r) if key == "" { key = r.RemoteAddr } limiter := rl.getLimiter(key) if !limiter.Allow() { http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } }