ResolveSpec/pkg/server/shutdown_test.go
2025-12-08 09:26:36 +02:00

232 lines
5.3 KiB
Go

package server
import (
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
)
// TestGracefulServerTrackRequests verifies that TrackRequestsMiddleware
// reports a non-zero in-flight count while request handlers are running and
// that the count returns to zero once every request has finished.
//
// The handler blocks on a channel instead of sleeping, so the in-flight check
// only runs after every request is known to be inside the handler. This keeps
// the test independent of scheduler timing.
func TestGracefulServerTrackRequests(t *testing.T) {
	const numRequests = 5

	// started receives one value per request that has entered the handler;
	// release is closed to let the blocked handlers finish.
	started := make(chan struct{}, numRequests)
	release := make(chan struct{})
	var releaseOnce sync.Once
	unblock := func() { releaseOnce.Do(func() { close(release) }) }
	// Release blocked handlers even if the test bails out early, so no
	// goroutine stays parked on the channel.
	defer unblock()

	srv := NewGracefulServer(Config{
		Addr: ":0",
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			started <- struct{}{}
			<-release
			w.WriteHeader(http.StatusOK)
		}),
	})

	handler := srv.TrackRequestsMiddleware(srv.server.Handler)

	// Fire the requests concurrently, straight at the wrapped handler.
	var wg sync.WaitGroup
	for i := 0; i < numRequests; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			req := httptest.NewRequest("GET", "/test", nil)
			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)
		}()
	}

	// Wait until every request is inside the handler. The timeout only guards
	// against a hang if the middleware never invokes the handler.
	timeout := time.After(2 * time.Second)
	for i := 0; i < numRequests; i++ {
		select {
		case <-started:
		case <-timeout:
			t.Fatalf("only %d of %d requests reached the handler", i, numRequests)
		}
	}

	// All handlers are blocked right now, so the counter must be non-zero.
	if inFlight := srv.InFlightRequests(); inFlight == 0 {
		t.Error("Should have in-flight requests")
	}

	// Let the handlers return and wait for every request to complete.
	unblock()
	wg.Wait()

	// The counter must be back to zero once nothing is running.
	if inFlight := srv.InFlightRequests(); inFlight != 0 {
		t.Errorf("In-flight requests should be 0, got %d", inFlight)
	}
}
// TestGracefulServerRejectsRequestsDuringShutdown checks that a request sent
// through TrackRequestsMiddleware after the server's shutting-down flag has
// been set is answered with 503 Service Unavailable.
func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	srv := NewGracefulServer(Config{Addr: ":0", Handler: okHandler})
	tracked := srv.TrackRequestsMiddleware(srv.server.Handler)

	// Flip the flag directly instead of running a real shutdown sequence.
	srv.isShuttingDown.Store(true)

	// Send a single request through the middleware.
	rec := httptest.NewRecorder()
	tracked.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))

	if got := rec.Code; got != http.StatusServiceUnavailable {
		t.Errorf("Expected 503, got %d", got)
	}
}
// TestHealthCheckHandler exercises the handler returned by HealthCheckHandler
// in two states: it expects 200 with a {"status":"healthy"} body on a fresh
// server, and 503 once the shutting-down flag has been set.
//
// The subtests share one server and rely on running in declaration order,
// because "ShuttingDown" sets the flag and never clears it.
func TestHealthCheckHandler(t *testing.T) {
	noop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	srv := NewGracefulServer(Config{Addr: ":0", Handler: noop})
	health := srv.HealthCheckHandler()

	// probe sends GET /health to the handler and returns the recorded response.
	probe := func() *httptest.ResponseRecorder {
		req := httptest.NewRequest("GET", "/health", nil)
		rec := httptest.NewRecorder()
		health.ServeHTTP(rec, req)
		return rec
	}

	t.Run("Healthy", func(t *testing.T) {
		rec := probe()
		if rec.Code != http.StatusOK {
			t.Errorf("Expected 200, got %d", rec.Code)
		}
		if body := rec.Body.String(); body != `{"status":"healthy"}` {
			t.Errorf("Unexpected body: %s", body)
		}
	})

	t.Run("ShuttingDown", func(t *testing.T) {
		srv.isShuttingDown.Store(true)
		if rec := probe(); rec.Code != http.StatusServiceUnavailable {
			t.Errorf("Expected 503, got %d", rec.Code)
		}
	})
}
// TestReadinessHandler exercises the handler returned by ReadinessHandler in
// two states: with no in-flight requests on a fresh server it expects 200 and
// the body {"ready":true,"in_flight_requests":0}; once the shutting-down flag
// has been set it expects 503.
//
// The subtests share one server and rely on running in declaration order,
// because "NotReady" sets the flag and never clears it.
func TestReadinessHandler(t *testing.T) {
	noop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	srv := NewGracefulServer(Config{Addr: ":0", Handler: noop})
	ready := srv.ReadinessHandler()

	// probe sends GET /ready to the handler and returns the recorded response.
	probe := func() *httptest.ResponseRecorder {
		req := httptest.NewRequest("GET", "/ready", nil)
		rec := httptest.NewRecorder()
		ready.ServeHTTP(rec, req)
		return rec
	}

	t.Run("Ready", func(t *testing.T) {
		rec := probe()
		if rec.Code != http.StatusOK {
			t.Errorf("Expected 200, got %d", rec.Code)
		}
		if got := rec.Body.String(); got != `{"ready":true,"in_flight_requests":0}` {
			t.Errorf("Unexpected body: %s", got)
		}
	})

	t.Run("NotReady", func(t *testing.T) {
		srv.isShuttingDown.Store(true)
		if rec := probe(); rec.Code != http.StatusServiceUnavailable {
			t.Errorf("Expected 503, got %d", rec.Code)
		}
	})
}
func TestShutdownCallbacks(t *testing.T) {
callbackExecuted := false
RegisterShutdownCallback(func(ctx context.Context) error {
callbackExecuted = true
return nil
})
ctx := context.Background()
err := executeShutdownCallbacks(ctx)
if err != nil {
t.Errorf("executeShutdownCallbacks() error = %v", err)
}
if !callbackExecuted {
t.Error("Shutdown callback was not executed")
}
// Reset for other tests
shutdownCallbacks = nil
}
// TestDrainRequests checks that drainRequests returns without error when the
// in-flight counter drops to zero before the drain timeout and the context
// deadline are reached.
func TestDrainRequests(t *testing.T) {
	const pending = 3

	noop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	srv := NewGracefulServer(Config{
		Addr:         ":0",
		Handler:      noop,
		DrainTimeout: 1 * time.Second,
	})

	// Pretend some requests are currently being served.
	srv.inFlightRequests.Add(pending)

	// After a short delay, simulate those requests completing while
	// drainRequests is waiting.
	time.AfterFunc(100*time.Millisecond, func() {
		srv.inFlightRequests.Add(-pending)
	})

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	if err := srv.drainRequests(ctx); err != nil {
		t.Errorf("drainRequests() error = %v", err)
	}
	if remaining := srv.InFlightRequests(); remaining != 0 {
		t.Errorf("In-flight requests should be 0, got %d", remaining)
	}
}
// TestDrainRequestsTimeout checks that drainRequests returns an error when the
// in-flight counter never reaches zero. The server is configured with a 100ms
// DrainTimeout and the context carries a 200ms deadline.
func TestDrainRequestsTimeout(t *testing.T) {
	const stuck = 5

	noop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	srv := NewGracefulServer(Config{
		Addr:         ":0",
		Handler:      noop,
		DrainTimeout: 100 * time.Millisecond,
	})

	// Pretend requests are being served; nothing will ever complete them.
	srv.inFlightRequests.Add(stuck)

	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()

	if err := srv.drainRequests(ctx); err == nil {
		t.Error("drainRequests() should timeout with error")
	}

	// Put the counter back to where it started.
	srv.inFlightRequests.Add(-stuck)
}
// TestGetClientIP is a placeholder kept for completeness of the server tests.
// According to the original note, getClientIP is used by the rate limiter and
// its tests live in ratelimit_test.go.
func TestGetClientIP(t *testing.T) {
	// Skip explicitly so the placeholder is visible in test output instead of
	// being reported as a passing test that asserts nothing.
	t.Skip("getClientIP is tested in ratelimit_test.go")
}