diff --git a/pkg/server/README.md b/pkg/server/README.md index 45342c9..6cd46b3 100644 --- a/pkg/server/README.md +++ b/pkg/server/README.md @@ -1,233 +1,314 @@ # Server Package -Graceful HTTP server with request draining and shutdown coordination. +Production-ready HTTP server manager with graceful shutdown, request draining, and comprehensive TLS/HTTPS support. + +## Features + +✅ **Multiple Server Management** - Run multiple HTTP/HTTPS servers concurrently +✅ **Graceful Shutdown** - Handles SIGINT/SIGTERM with request draining +✅ **Automatic Request Rejection** - New requests get 503 during shutdown +✅ **Health & Readiness Endpoints** - Kubernetes-ready health checks +✅ **Shutdown Callbacks** - Register cleanup functions (DB, cache, metrics) +✅ **Comprehensive TLS Support**: + - Certificate files (production) + - Self-signed certificates (development/testing) + - Let's Encrypt / AutoTLS (automatic certificate management) +✅ **GZIP Compression** - Optional response compression +✅ **Panic Recovery** - Automatic panic recovery middleware +✅ **Configurable Timeouts** - Read, write, idle, drain, and shutdown timeouts ## Quick Start +### Single Server + ```go import "github.com/bitechdev/ResolveSpec/pkg/server" -// Create server -srv := server.NewGracefulServer(server.Config{ - Addr: ":8080", - Handler: router, +// Create server manager +mgr := server.NewManager() + +// Add server +_, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: myRouter, + GZIP: true, }) -// Start server (blocks until shutdown signal) -if err := srv.ListenAndServe(); err != nil { +// Start and wait for shutdown signal +if err := mgr.ServeWithGracefulShutdown(); err != nil { log.Fatal(err) } ``` -## Features +### Multiple Servers -✅ Graceful shutdown on SIGINT/SIGTERM -✅ Request draining (waits for in-flight requests) -✅ Automatic request rejection during shutdown -✅ Health and readiness endpoints -✅ Shutdown callbacks for cleanup -✅ Configurable 
timeouts +```go +mgr := server.NewManager() + +// Public API +mgr.Add(server.Config{ + Name: "public-api", + Port: 8080, + Handler: publicRouter, +}) + +// Admin API +mgr.Add(server.Config{ + Name: "admin-api", + Port: 8081, + Handler: adminRouter, +}) + +// Start all and wait +mgr.ServeWithGracefulShutdown() +``` + +## HTTPS/TLS Configuration + +### Option 1: Certificate Files (Production) + +```go +mgr.Add(server.Config{ + Name: "https-server", + Host: "0.0.0.0", + Port: 443, + Handler: handler, + SSLCert: "/etc/ssl/certs/server.crt", + SSLKey: "/etc/ssl/private/server.key", +}) +``` + +### Option 2: Self-Signed Certificate (Development) + +```go +mgr.Add(server.Config{ + Name: "dev-server", + Host: "localhost", + Port: 8443, + Handler: handler, + SelfSignedSSL: true, // Auto-generates certificate +}) +``` + +### Option 3: Let's Encrypt / AutoTLS (Production) + +```go +mgr.Add(server.Config{ + Name: "prod-server", + Host: "0.0.0.0", + Port: 443, + Handler: handler, + AutoTLS: true, + AutoTLSDomains: []string{"example.com", "www.example.com"}, + AutoTLSEmail: "admin@example.com", + AutoTLSCacheDir: "./certs-cache", // Certificate cache directory +}) +``` ## Configuration ```go -config := server.Config{ - // Server address - Addr: ":8080", +server.Config{ + // Basic configuration + Name: "my-server", // Server name (required) + Host: "0.0.0.0", // Bind address + Port: 8080, // Port (required) + Handler: myRouter, // HTTP handler (required) + Description: "My API server", // Optional description - // HTTP handler - Handler: myRouter, + // Features + GZIP: true, // Enable GZIP compression - // Maximum time for graceful shutdown (default: 30s) - ShutdownTimeout: 30 * time.Second, + // TLS/HTTPS (choose one option) + SSLCert: "/path/to/cert.pem", // Certificate file + SSLKey: "/path/to/key.pem", // Key file + SelfSignedSSL: false, // Auto-generate self-signed cert + AutoTLS: false, // Let's Encrypt + AutoTLSDomains: []string{}, // Domains for AutoTLS + AutoTLSEmail: 
"", // Email for Let's Encrypt + AutoTLSCacheDir: "./certs-cache", // Cert cache directory - // Time to wait for in-flight requests (default: 25s) - DrainTimeout: 25 * time.Second, - - // Request read timeout (default: 10s) - ReadTimeout: 10 * time.Second, - - // Response write timeout (default: 10s) - WriteTimeout: 10 * time.Second, - - // Idle connection timeout (default: 120s) - IdleTimeout: 120 * time.Second, + // Timeouts + ShutdownTimeout: 30 * time.Second, // Max shutdown time + DrainTimeout: 25 * time.Second, // Request drain timeout + ReadTimeout: 15 * time.Second, // Request read timeout + WriteTimeout: 15 * time.Second, // Response write timeout + IdleTimeout: 60 * time.Second, // Idle connection timeout } - -srv := server.NewGracefulServer(config) ``` -## Shutdown Behavior +## Graceful Shutdown -**Signal received (SIGINT/SIGTERM):** +### Automatic (Recommended) -1. **Mark as shutting down** - New requests get 503 -2. **Drain requests** - Wait up to `DrainTimeout` for in-flight requests -3. **Shutdown server** - Close listeners and connections -4. **Execute callbacks** - Run registered cleanup functions +```go +mgr := server.NewManager() +// Add servers... + +// This blocks until SIGINT/SIGTERM +mgr.ServeWithGracefulShutdown() ``` -Time Event -───────────────────────────────────────── -0s Signal received: SIGTERM - ├─ Mark as shutting down - ├─ Reject new requests (503) - └─ Start draining... -1s In-flight: 50 requests -2s In-flight: 32 requests -3s In-flight: 12 requests -4s In-flight: 3 requests -5s In-flight: 0 requests ✓ - └─ All requests drained +### Manual Control -5s Execute shutdown callbacks -6s Shutdown complete +```go +mgr := server.NewManager() + +// Add and start servers +mgr.StartAll() + +// Later... 
stop gracefully +ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +defer cancel() + +if err := mgr.StopAllWithContext(ctx); err != nil { + log.Printf("Shutdown error: %v", err) +} +``` + +### Shutdown Callbacks + +Register cleanup functions to run during shutdown: + +```go +// Close database +mgr.RegisterShutdownCallback(func(ctx context.Context) error { + log.Println("Closing database...") + return db.Close() +}) + +// Flush metrics +mgr.RegisterShutdownCallback(func(ctx context.Context) error { + log.Println("Flushing metrics...") + return metrics.Flush(ctx) +}) + +// Close cache +mgr.RegisterShutdownCallback(func(ctx context.Context) error { + log.Println("Closing cache...") + return cache.Close() +}) ``` ## Health Checks -### Health Endpoint - -Returns 200 when healthy, 503 when shutting down: +### Adding Health Endpoints ```go -router.HandleFunc("/health", srv.HealthCheckHandler()) +instance, _ := mgr.Add(server.Config{ + Name: "api-server", + Port: 8080, + Handler: router, +}) + +// Add health endpoints to your router +router.HandleFunc("/health", instance.HealthCheckHandler()) +router.HandleFunc("/ready", instance.ReadinessHandler()) ``` -**Response (healthy):** +### Health Endpoint + +Returns server health status: + +**Healthy (200 OK):** ```json {"status":"healthy"} ``` -**Response (shutting down):** +**Shutting Down (503 Service Unavailable):** ```json {"status":"shutting_down"} ``` ### Readiness Endpoint -Includes in-flight request count: +Returns readiness with in-flight request count: -```go -router.HandleFunc("/ready", srv.ReadinessHandler()) -``` - -**Response:** +**Ready (200 OK):** ```json {"ready":true,"in_flight_requests":12} ``` -**During shutdown:** +**Not Ready (503 Service Unavailable):** ```json {"ready":false,"reason":"shutting_down"} ``` -## Shutdown Callbacks +## Shutdown Behavior -Register cleanup functions to run during shutdown: +When a shutdown signal (SIGINT/SIGTERM) is received: -```go -// Close database 
-server.RegisterShutdownCallback(func(ctx context.Context) error { - logger.Info("Closing database connection...") - return db.Close() -}) +1. **Mark as shutting down** → New requests get 503 +2. **Execute callbacks** → Run cleanup functions +3. **Drain requests** → Wait up to `DrainTimeout` for in-flight requests +4. **Shutdown servers** → Close listeners and connections -// Flush metrics -server.RegisterShutdownCallback(func(ctx context.Context) error { - logger.Info("Flushing metrics...") - return metricsProvider.Flush(ctx) -}) +``` +Time Event +───────────────────────────────────────── +0s Signal received: SIGTERM + ├─ Mark servers as shutting down + ├─ Reject new requests (503) + └─ Execute shutdown callbacks -// Close cache -server.RegisterShutdownCallback(func(ctx context.Context) error { - logger.Info("Closing cache...") - return cache.Close() -}) +1s Callbacks complete + └─ Start draining requests... + +2s In-flight: 50 requests +3s In-flight: 32 requests +4s In-flight: 12 requests +5s In-flight: 3 requests +6s In-flight: 0 requests ✓ + └─ All requests drained + +6s Shutdown servers +7s All servers stopped ✓ ``` -## Complete Example +## Server Management + +### Get Server Instance ```go -package main - -import ( - "context" - "log" - "net/http" - "time" - - "github.com/bitechdev/ResolveSpec/pkg/middleware" - "github.com/bitechdev/ResolveSpec/pkg/metrics" - "github.com/bitechdev/ResolveSpec/pkg/server" - "github.com/gorilla/mux" -) - -func main() { - // Initialize metrics - metricsProvider := metrics.NewPrometheusProvider() - metrics.SetProvider(metricsProvider) - - // Create router - router := mux.NewRouter() - - // Apply middleware - rateLimiter := middleware.NewRateLimiter(100, 20) - sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size10MB) - sanitizer := middleware.DefaultSanitizer() - - router.Use(rateLimiter.Middleware) - router.Use(sizeLimiter.Middleware) - router.Use(sanitizer.Middleware) - router.Use(metricsProvider.Middleware) - - // 
API routes - router.HandleFunc("/api/data", dataHandler) - - // Create graceful server - srv := server.NewGracefulServer(server.Config{ - Addr: ":8080", - Handler: router, - ShutdownTimeout: 30 * time.Second, - DrainTimeout: 25 * time.Second, - }) - - // Health checks - router.HandleFunc("/health", srv.HealthCheckHandler()) - router.HandleFunc("/ready", srv.ReadinessHandler()) - - // Metrics endpoint - router.Handle("/metrics", metricsProvider.Handler()) - - // Register shutdown callbacks - server.RegisterShutdownCallback(func(ctx context.Context) error { - log.Println("Cleanup: Flushing metrics...") - return nil - }) - - server.RegisterShutdownCallback(func(ctx context.Context) error { - log.Println("Cleanup: Closing database...") - // return db.Close() - return nil - }) - - // Start server (blocks until shutdown) - log.Printf("Starting server on :8080") - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } - - // Wait for shutdown to complete - srv.Wait() - log.Println("Server stopped") +instance, err := mgr.Get("api-server") +if err != nil { + log.Fatal(err) } -func dataHandler(w http.ResponseWriter, r *http.Request) { - // Your handler logic - time.Sleep(100 * time.Millisecond) // Simulate work - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"message":"success"}`)) +// Check status +fmt.Printf("Address: %s\n", instance.Addr()) +fmt.Printf("Name: %s\n", instance.Name()) +fmt.Printf("In-flight: %d\n", instance.InFlightRequests()) +fmt.Printf("Shutting down: %v\n", instance.IsShuttingDown()) +``` + +### List All Servers + +```go +instances := mgr.List() +for _, instance := range instances { + fmt.Printf("Server: %s at %s\n", instance.Name(), instance.Addr()) +} +``` + +### Remove Server + +```go +// Stop and remove a server +if err := mgr.Remove("api-server"); err != nil { + log.Printf("Error removing server: %v", err) +} +``` + +### Restart All Servers + +```go +// Gracefully restart all servers +if err := mgr.RestartAll(); err != nil { + 
log.Printf("Error restarting: %v", err) } ``` @@ -250,23 +331,21 @@ spec: ports: - containerPort: 8080 - # Liveness probe - is app running? + # Liveness probe livenessProbe: httpGet: path: /health port: 8080 initialDelaySeconds: 10 periodSeconds: 10 - timeoutSeconds: 5 - # Readiness probe - can app handle traffic? + # Readiness probe readinessProbe: httpGet: path: /ready port: 8080 initialDelaySeconds: 5 periodSeconds: 5 - timeoutSeconds: 3 # Graceful shutdown lifecycle: @@ -274,26 +353,12 @@ spec: exec: command: ["/bin/sh", "-c", "sleep 5"] - # Environment env: - name: SHUTDOWN_TIMEOUT value: "30" -``` -### Service - -```yaml -apiVersion: v1 -kind: Service -metadata: - name: myapp -spec: - selector: - app: myapp - ports: - - port: 80 - targetPort: 8080 - type: LoadBalancer + # Allow time for graceful shutdown + terminationGracePeriodSeconds: 35 ``` ## Docker Compose @@ -312,8 +377,70 @@ services: interval: 10s timeout: 5s retries: 3 - start_period: 10s - stop_grace_period: 35s # Slightly longer than shutdown timeout + stop_grace_period: 35s +``` + +## Complete Example + +```go +package main + +import ( + "context" + "log" + "net/http" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/server" +) + +func main() { + // Create server manager + mgr := server.NewManager() + + // Register shutdown callbacks + mgr.RegisterShutdownCallback(func(ctx context.Context) error { + log.Println("Cleanup: Closing database...") + // return db.Close() + return nil + }) + + // Create router + router := http.NewServeMux() + router.HandleFunc("/api/data", dataHandler) + + // Add server + instance, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "0.0.0.0", + Port: 8080, + Handler: router, + GZIP: true, + ShutdownTimeout: 30 * time.Second, + DrainTimeout: 25 * time.Second, + }) + if err != nil { + log.Fatal(err) + } + + // Add health endpoints + router.HandleFunc("/health", instance.HealthCheckHandler()) + router.HandleFunc("/ready", instance.ReadinessHandler()) + + // Start 
and wait for shutdown + log.Println("Starting server on :8080") + if err := mgr.ServeWithGracefulShutdown(); err != nil { + log.Printf("Server stopped: %v", err) + } + + log.Println("Server shutdown complete") +} + +func dataHandler(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) // Simulate work + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"message":"success"}`)) +} ``` ## Testing Graceful Shutdown @@ -330,7 +457,7 @@ SERVER_PID=$! # Wait for server to start sleep 2 -# Send some requests +# Send requests for i in {1..10}; do curl http://localhost:8080/api/data & done @@ -341,7 +468,7 @@ sleep 1 # Send shutdown signal kill -TERM $SERVER_PID -# Try to send more requests (should get 503) +# Try more requests (should get 503) curl -v http://localhost:8080/api/data # Wait for server to stop @@ -349,101 +476,13 @@ wait $SERVER_PID echo "Server stopped gracefully" ``` -### Expected Output - -``` -Starting server on :8080 -Received signal: terminated, initiating graceful shutdown -Starting graceful shutdown... -Waiting for 8 in-flight requests to complete... -Waiting for 4 in-flight requests to complete... -Waiting for 1 in-flight requests to complete... -All requests drained in 2.3s -Cleanup: Flushing metrics... -Cleanup: Closing database... -Shutting down HTTP server... 
-Graceful shutdown complete -Server stopped -``` - -## Monitoring In-Flight Requests - -```go -// Get current in-flight count -count := srv.InFlightRequests() -fmt.Printf("In-flight requests: %d\n", count) - -// Check if shutting down -if srv.IsShuttingDown() { - fmt.Println("Server is shutting down") -} -``` - -## Advanced Usage - -### Custom Shutdown Logic - -```go -// Implement custom shutdown -go func() { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - - <-sigChan - log.Println("Shutdown signal received") - - // Custom pre-shutdown logic - log.Println("Running custom cleanup...") - - // Shutdown with callbacks - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := srv.ShutdownWithCallbacks(ctx); err != nil { - log.Printf("Shutdown error: %v", err) - } -}() - -// Start server -srv.server.ListenAndServe() -``` - -### Multiple Servers - -```go -// HTTP server -httpSrv := server.NewGracefulServer(server.Config{ - Addr: ":8080", - Handler: httpRouter, -}) - -// HTTPS server -httpsSrv := server.NewGracefulServer(server.Config{ - Addr: ":8443", - Handler: httpsRouter, -}) - -// Start both -go httpSrv.ListenAndServe() -go httpsSrv.ListenAndServe() - -// Shutdown both on signal -sigChan := make(chan os.Signal, 1) -signal.Notify(sigChan, os.Interrupt) -<-sigChan - -ctx := context.Background() -httpSrv.Shutdown(ctx) -httpsSrv.Shutdown(ctx) -``` - ## Best Practices 1. **Set appropriate timeouts** - `DrainTimeout` < `ShutdownTimeout` - `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds` -2. **Register cleanup callbacks** for: +2. **Use shutdown callbacks** for: - Database connections - Message queues - Metrics flushing @@ -458,7 +497,12 @@ httpsSrv.Shutdown(ctx) - Set `preStop` hook in Kubernetes (5-10s delay) - Allows load balancer to deregister before shutdown -5. **Monitoring** +5. 
**HTTPS in production** + - Use AutoTLS for public-facing services + - Use certificate files for enterprise PKI + - Use self-signed only for development/testing + +6. **Monitoring** - Track in-flight requests in metrics - Alert on slow drains - Monitor shutdown duration @@ -470,24 +514,63 @@ httpsSrv.Shutdown(ctx) ```go // Increase drain timeout config.DrainTimeout = 60 * time.Second +config.ShutdownTimeout = 65 * time.Second ``` -### Requests Still Timing Out +### Requests Timing Out ```go // Increase write timeout config.WriteTimeout = 30 * time.Second ``` -### Force Shutdown Not Working - -The server will force shutdown after `ShutdownTimeout` even if requests are still in-flight. Adjust timeouts as needed. - -### Debugging Shutdown +### Certificate Issues + +```go +// Verify certificate files exist and are readable +if _, err := os.Stat(config.SSLCert); err != nil { + log.Fatalf("Certificate not found: %v", err) +} + +// For AutoTLS, ensure: +// - Port 443 is accessible +// - Domains resolve to server IP +// - Cache directory is writable +``` + +### Debug Logging ```go -// Enable debug logging import "github.com/bitechdev/ResolveSpec/pkg/logger" +// Enable debug logging logger.SetLevel("debug") ``` + +## API Reference + +### Manager Methods + +- `NewManager()` - Create new server manager +- `Add(cfg Config)` - Register server instance +- `Get(name string)` - Get server by name +- `Remove(name string)` - Stop and remove server +- `StartAll()` - Start all registered servers +- `StopAll()` - Stop all servers gracefully +- `StopAllWithContext(ctx)` - Stop with timeout +- `RestartAll()` - Restart all servers +- `List()` - Get all server instances +- `ServeWithGracefulShutdown()` - Start and block until shutdown +- `RegisterShutdownCallback(cb)` - Register cleanup function + +### Instance Methods + +- `Start()` - Start the server +- `Stop(ctx)` - Stop gracefully +- `Addr()` - Get server address +- `Name()` - Get server name +- `HealthCheckHandler()` - Get health 
handler +- `ReadinessHandler()` - Get readiness handler +- `InFlightRequests()` - Get in-flight count +- `IsShuttingDown()` - Check shutdown status +- `Wait()` - Block until shutdown complete diff --git a/pkg/server/example_test.go b/pkg/server/example_test.go new file mode 100644 index 0000000..032b6f0 --- /dev/null +++ b/pkg/server/example_test.go @@ -0,0 +1,294 @@ +package server_test + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/server" +) + +// ExampleManager_basic demonstrates basic server manager usage +func ExampleManager_basic() { + // Create a server manager + mgr := server.NewManager() + + // Define a simple handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "Hello from server!") + }) + + // Add an HTTP server + _, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: handler, + GZIP: true, // Enable GZIP compression + }) + if err != nil { + panic(err) + } + + // Start all servers + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Server is now running... 
+ // When done, stop gracefully + if err := mgr.StopAll(); err != nil { + panic(err) + } +} + +// ExampleManager_https demonstrates HTTPS configurations +func ExampleManager_https() { + mgr := server.NewManager() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Secure connection!") + }) + + // Option 1: Use certificate files + _, err := mgr.Add(server.Config{ + Name: "https-server-files", + Host: "localhost", + Port: 8443, + Handler: handler, + SSLCert: "/path/to/cert.pem", + SSLKey: "/path/to/key.pem", + }) + if err != nil { + panic(err) + } + + // Option 2: Self-signed certificate (for development) + _, err = mgr.Add(server.Config{ + Name: "https-server-self-signed", + Host: "localhost", + Port: 8444, + Handler: handler, + SelfSignedSSL: true, + }) + if err != nil { + panic(err) + } + + // Option 3: Let's Encrypt / AutoTLS (for production) + _, err = mgr.Add(server.Config{ + Name: "https-server-letsencrypt", + Host: "0.0.0.0", + Port: 443, + Handler: handler, + AutoTLS: true, + AutoTLSDomains: []string{"example.com", "www.example.com"}, + AutoTLSEmail: "admin@example.com", + AutoTLSCacheDir: "./certs-cache", + }) + if err != nil { + panic(err) + } + + // Start all servers + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Cleanup + mgr.StopAll() +} + +// ExampleManager_gracefulShutdown demonstrates graceful shutdown with callbacks +func ExampleManager_gracefulShutdown() { + mgr := server.NewManager() + + // Register shutdown callbacks for cleanup tasks + mgr.RegisterShutdownCallback(func(ctx context.Context) error { + fmt.Println("Closing database connections...") + // Close your database here + return nil + }) + + mgr.RegisterShutdownCallback(func(ctx context.Context) error { + fmt.Println("Flushing metrics...") + // Flush metrics here + return nil + }) + + // Add server with custom timeouts + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate some work + 
time.Sleep(100 * time.Millisecond) + fmt.Fprintln(w, "Done!") + }) + + _, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: handler, + ShutdownTimeout: 30 * time.Second, // Max time for shutdown + DrainTimeout: 25 * time.Second, // Time to wait for in-flight requests + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + }) + if err != nil { + panic(err) + } + + // Start servers and block until shutdown signal (SIGINT/SIGTERM) + // This will automatically handle graceful shutdown with callbacks + if err := mgr.ServeWithGracefulShutdown(); err != nil { + fmt.Printf("Shutdown completed: %v\n", err) + } +} + +// ExampleManager_healthChecks demonstrates health and readiness endpoints +func ExampleManager_healthChecks() { + mgr := server.NewManager() + + // Create a router with health endpoints + mux := http.NewServeMux() + mux.HandleFunc("/api/data", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Data endpoint") + }) + + // Add server + instance, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: mux, + }) + if err != nil { + panic(err) + } + + // Add health and readiness endpoints + mux.HandleFunc("/health", instance.HealthCheckHandler()) + mux.HandleFunc("/ready", instance.ReadinessHandler()) + + // Start the server + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Health check returns: + // - 200 OK with {"status":"healthy"} when healthy + // - 503 Service Unavailable with {"status":"shutting_down"} when shutting down + + // Readiness check returns: + // - 200 OK with {"ready":true,"in_flight_requests":N} when ready + // - 503 Service Unavailable with {"ready":false,"reason":"shutting_down"} when shutting down + + // Cleanup + mgr.StopAll() +} + +// ExampleManager_multipleServers demonstrates running multiple servers +func ExampleManager_multipleServers() { + mgr := server.NewManager() + + // 
Public API server + publicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Public API") + }) + _, err := mgr.Add(server.Config{ + Name: "public-api", + Host: "0.0.0.0", + Port: 8080, + Handler: publicHandler, + GZIP: true, + }) + if err != nil { + panic(err) + } + + // Admin API server (different port) + adminHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Admin API") + }) + _, err = mgr.Add(server.Config{ + Name: "admin-api", + Host: "localhost", + Port: 8081, + Handler: adminHandler, + }) + if err != nil { + panic(err) + } + + // Metrics server (internal only) + metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Metrics data") + }) + _, err = mgr.Add(server.Config{ + Name: "metrics", + Host: "127.0.0.1", + Port: 9090, + Handler: metricsHandler, + }) + if err != nil { + panic(err) + } + + // Start all servers at once + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Get specific server instance + publicInstance, err := mgr.Get("public-api") + if err != nil { + panic(err) + } + fmt.Printf("Public API running on: %s\n", publicInstance.Addr()) + + // List all servers + instances := mgr.List() + fmt.Printf("Running %d servers\n", len(instances)) + + // Stop all servers gracefully (in parallel) + if err := mgr.StopAll(); err != nil { + panic(err) + } +} + +// ExampleManager_monitoring demonstrates monitoring server state +func ExampleManager_monitoring() { + mgr := server.NewManager() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) // Simulate work + fmt.Fprintln(w, "Done") + }) + + instance, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: handler, + }) + if err != nil { + panic(err) + } + + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Check server status + fmt.Printf("Server 
address: %s\n", instance.Addr()) + fmt.Printf("Server name: %s\n", instance.Name()) + fmt.Printf("Is shutting down: %v\n", instance.IsShuttingDown()) + fmt.Printf("In-flight requests: %d\n", instance.InFlightRequests()) + + // Cleanup + mgr.StopAll() + + // Wait for complete shutdown + instance.Wait() +} diff --git a/pkg/server/interfaces.go b/pkg/server/interfaces.go index 892cb58..633c394 100644 --- a/pkg/server/interfaces.go +++ b/pkg/server/interfaces.go @@ -3,6 +3,7 @@ package server import ( "context" "net/http" + "time" ) // Config holds the configuration for a single web server instance. @@ -11,11 +12,52 @@ type Config struct { Host string Port int Description string - SSLCert string - SSLKey string - GZIP bool + // Handler is the http.Handler (e.g., a router) to be served. Handler http.Handler + + // GZIP compression support + GZIP bool + + // TLS/HTTPS configuration options (mutually exclusive) + // Option 1: Provide certificate and key files directly + SSLCert string + SSLKey string + + // Option 2: Use self-signed certificate (for development/testing) + // Generates a self-signed certificate automatically if no SSLCert/SSLKey provided + SelfSignedSSL bool + + // Option 3: Use Let's Encrypt / Certbot for automatic TLS + // AutoTLS enables automatic certificate management via Let's Encrypt + AutoTLS bool + // AutoTLSDomains specifies the domains for Let's Encrypt certificates + AutoTLSDomains []string + // AutoTLSCacheDir specifies where to cache certificates (default: "./certs-cache") + AutoTLSCacheDir string + // AutoTLSEmail is the email for Let's Encrypt registration (optional but recommended) + AutoTLSEmail string + + // Graceful shutdown configuration + // ShutdownTimeout is the maximum time to wait for graceful shutdown + // Default: 30 seconds + ShutdownTimeout time.Duration + + // DrainTimeout is the time to wait for in-flight requests to complete + // before forcing shutdown. 
Default: 25 seconds + DrainTimeout time.Duration + + // ReadTimeout is the maximum duration for reading the entire request + // Default: 15 seconds + ReadTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out writes of the response + // Default: 15 seconds + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the next request + // Default: 60 seconds + IdleTimeout time.Duration } // Instance defines the interface for a single server instance. @@ -24,11 +66,33 @@ type Instance interface { // Start begins serving requests. This method should be non-blocking and // run the server in a separate goroutine. Start() error + // Stop gracefully shuts down the server without interrupting any active connections. // It accepts a context to allow for a timeout. Stop(ctx context.Context) error + // Addr returns the network address the server is listening on. Addr() string + + // Name returns the server instance name. + Name() string + + // HealthCheckHandler returns a handler that responds to health checks. + // Returns 200 OK when healthy, 503 Service Unavailable when shutting down. + HealthCheckHandler() http.HandlerFunc + + // ReadinessHandler returns a handler for readiness checks. + // Includes in-flight request count. + ReadinessHandler() http.HandlerFunc + + // InFlightRequests returns the current number of in-flight requests. + InFlightRequests() int64 + + // IsShuttingDown returns true if the server is shutting down. + IsShuttingDown() bool + + // Wait blocks until shutdown is complete. + Wait() } // Manager defines the interface for a server manager. @@ -37,16 +101,37 @@ type Manager interface { // Add registers a new server instance based on the provided configuration. // The server is not started until StartAll or Start is called on the instance. Add(cfg Config) (Instance, error) + // Get returns a server instance by its name. 
Get(name string) (Instance, error) + // Remove stops and removes a server instance by its name. Remove(name string) error + // StartAll starts all registered server instances that are not already running. StartAll() error + // StopAll gracefully shuts down all running server instances. + // Executes shutdown callbacks and drains in-flight requests. StopAll() error + + // StopAllWithContext gracefully shuts down all running server instances with a context. + StopAllWithContext(ctx context.Context) error + // RestartAll gracefully restarts all running server instances. RestartAll() error + // List returns all registered server instances. List() []Instance + + // ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received. + // It handles SIGINT and SIGTERM signals and performs graceful shutdown with callbacks. + ServeWithGracefulShutdown() error + + // RegisterShutdownCallback registers a callback to be called during shutdown. + // Useful for cleanup tasks like closing database connections, flushing metrics, etc. + RegisterShutdownCallback(cb ShutdownCallback) } + +// ShutdownCallback is a function called during graceful shutdown. +type ShutdownCallback func(context.Context) error diff --git a/pkg/server/manager.go b/pkg/server/manager.go index 005d8fa..02f0c3a 100644 --- a/pkg/server/manager.go +++ b/pkg/server/manager.go @@ -4,26 +4,173 @@ import ( "context" "crypto/tls" "fmt" + "net" "net/http" + "os" + "os/signal" "sync" + "sync/atomic" + "syscall" "time" "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/middleware" "github.com/klauspost/compress/gzhttp" - "golang.org/x/net/http2" ) -// serverManager manages a collection of server instances. 
+// gracefulServer wraps http.Server with graceful shutdown capabilities (internal type) +type gracefulServer struct { + server *http.Server + shutdownTimeout time.Duration + drainTimeout time.Duration + inFlightRequests atomic.Int64 + isShuttingDown atomic.Bool + shutdownOnce sync.Once + shutdownComplete chan struct{} +} + +// trackRequestsMiddleware tracks in-flight requests and rejects new requests with 503 during shutdown +func (gs *gracefulServer) trackRequestsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if shutting down + if gs.isShuttingDown.Load() { + http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable) + return + } + + // Increment in-flight counter + gs.inFlightRequests.Add(1) + defer gs.inFlightRequests.Add(-1) + + // Serve the request + next.ServeHTTP(w, r) + }) +} + +// shutdown performs graceful shutdown with request draining +func (gs *gracefulServer) shutdown(ctx context.Context) error { + var shutdownErr error + + gs.shutdownOnce.Do(func() { + logger.Info("Starting graceful shutdown...") + + // Mark as shutting down (new requests will be rejected) + gs.isShuttingDown.Store(true) + + // Create context with timeout + shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout) + defer cancel() + + // Wait for in-flight requests to complete (with drain timeout) + drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout) + defer drainCancel() + + shutdownErr = gs.drainRequests(drainCtx) + if shutdownErr != nil { + logger.Error("Error draining requests: %v", shutdownErr) + } + + // Shutdown the server + logger.Info("Shutting down HTTP server...") + if err := gs.server.Shutdown(shutdownCtx); err != nil { + logger.Error("Error shutting down server: %v", err) + if shutdownErr == nil { + shutdownErr = err + } + } + + logger.Info("Graceful shutdown complete") + close(gs.shutdownComplete) + }) + + 
return shutdownErr +} + +// drainRequests waits for in-flight requests to complete +func (gs *gracefulServer) drainRequests(ctx context.Context) error { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + startTime := time.Now() + + for { + inFlight := gs.inFlightRequests.Load() + + if inFlight == 0 { + logger.Info("All requests drained in %v", time.Since(startTime)) + return nil + } + + select { + case <-ctx.Done(): + logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight) + return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight) + case <-ticker.C: + logger.Debug("Waiting for %d in-flight requests to complete...", inFlight) + } + } +} + +// inFlightRequests returns the current number of in-flight requests +func (gs *gracefulServer) inFlightRequestsCount() int64 { + return gs.inFlightRequests.Load() +} + +// isShutdown returns true if the server is shutting down +func (gs *gracefulServer) isShutdown() bool { + return gs.isShuttingDown.Load() +} + +// wait blocks until shutdown is complete +func (gs *gracefulServer) wait() { + <-gs.shutdownComplete +} + +// healthCheckHandler returns a handler that responds to health checks +func (gs *gracefulServer) healthCheckHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if gs.isShutdown() { + http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"status":"healthy"}`)) + if err != nil { + logger.Warn("Failed to write health check response: %v", err) + } + } +} + +// readinessHandler returns a handler for readiness checks +func (gs *gracefulServer) readinessHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if gs.isShutdown() { + http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable) + return + } + + 
inFlight := gs.inFlightRequestsCount() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight) + } +} + +// serverManager manages a collection of server instances with graceful shutdown support. type serverManager struct { - instances map[string]Instance - mu sync.RWMutex + instances map[string]Instance + mu sync.RWMutex + shutdownCallbacks []ShutdownCallback + callbacksMu sync.Mutex } // NewManager creates a new server manager. func NewManager() Manager { return &serverManager{ - instances: make(map[string]Instance), + instances: make(map[string]Instance), + shutdownCallbacks: make([]ShutdownCallback, 0), } } @@ -74,7 +221,7 @@ func (sm *serverManager) Remove(name string) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := instance.Stop(ctx); err != nil { - logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err, context.Background()) + logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err) } delete(sm.instances, name) @@ -94,7 +241,6 @@ func (sm *serverManager) StartAll() error { } if len(startErrors) > 0 { - // In a real-world scenario, you might want a more sophisticated error handling strategy return fmt.Errorf("encountered errors while starting servers: %v", startErrors) } return nil @@ -102,6 +248,11 @@ func (sm *serverManager) StartAll() error { // StopAll gracefully shuts down all running server instances. func (sm *serverManager) StopAll() error { + return sm.StopAllWithContext(context.Background()) +} + +// StopAllWithContext gracefully shuts down all running server instances with a context. 
+func (sm *serverManager) StopAllWithContext(ctx context.Context) error { sm.mu.RLock() instancesToStop := make([]Instance, 0, len(sm.instances)) for _, instance := range sm.instances { @@ -109,19 +260,38 @@ func (sm *serverManager) StopAll() error { } sm.mu.RUnlock() - logger.Info("Shutting down all servers...", context.Background()) + logger.Info("Shutting down all servers...") + // Execute shutdown callbacks first + sm.callbacksMu.Lock() + callbacks := make([]ShutdownCallback, len(sm.shutdownCallbacks)) + copy(callbacks, sm.shutdownCallbacks) + sm.callbacksMu.Unlock() + + if len(callbacks) > 0 { + logger.Info("Executing %d shutdown callbacks...", len(callbacks)) + for i, cb := range callbacks { + if err := cb(ctx); err != nil { + logger.Error("Shutdown callback %d failed: %v", i+1, err) + } + } + } + + // Stop all instances in parallel var shutdownErrors []error var wg sync.WaitGroup + var errorsMu sync.Mutex for _, instance := range instancesToStop { wg.Add(1) go func(inst Instance) { defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second) defer cancel() - if err := inst.Stop(ctx); err != nil { - shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Addr(), err)) + if err := inst.Stop(shutdownCtx); err != nil { + errorsMu.Lock() + shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Name(), err)) + errorsMu.Unlock() } }(instance) } @@ -131,13 +301,13 @@ func (sm *serverManager) StopAll() error { if len(shutdownErrors) > 0 { return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors) } - logger.Info("All servers stopped gracefully.", context.Background()) + logger.Info("All servers stopped gracefully.") return nil } // RestartAll gracefully restarts all running server instances. 
func (sm *serverManager) RestartAll() error { - logger.Info("Restarting all servers...", context.Background()) + logger.Info("Restarting all servers...") if err := sm.StopAll(); err != nil { return fmt.Errorf("failed to stop servers during restart: %w", err) } @@ -148,7 +318,7 @@ func (sm *serverManager) RestartAll() error { if err := sm.StartAll(); err != nil { return fmt.Errorf("failed to start servers during restart: %w", err) } - logger.Info("All servers restarted successfully.", context.Background()) + logger.Info("All servers restarted successfully.") return nil } @@ -164,13 +334,46 @@ func (sm *serverManager) List() []Instance { return instances } +// RegisterShutdownCallback registers a callback to be called during shutdown. +func (sm *serverManager) RegisterShutdownCallback(cb ShutdownCallback) { + sm.callbacksMu.Lock() + defer sm.callbacksMu.Unlock() + sm.shutdownCallbacks = append(sm.shutdownCallbacks, cb) +} + +// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received. +func (sm *serverManager) ServeWithGracefulShutdown() error { + // Start all servers + if err := sm.StartAll(); err != nil { + return fmt.Errorf("failed to start servers: %w", err) + } + + logger.Info("All servers started. Waiting for shutdown signal...") + + // Wait for interrupt signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + + sig := <-sigChan + logger.Info("Received signal: %v, initiating graceful shutdown", sig) + + // Create context with timeout for shutdown + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + return sm.StopAllWithContext(ctx) +} + // serverInstance is a concrete implementation of the Instance interface. +// It wraps gracefulServer to provide graceful shutdown capabilities. 
type serverInstance struct { - cfg Config - httpServer *http.Server - mu sync.RWMutex - running bool - stopCh chan struct{} + cfg Config + gracefulServer *gracefulServer + certFile string // Path to certificate file (may be temporary for self-signed) + keyFile string // Path to key file (may be temporary for self-signed) + mu sync.RWMutex + running bool + serverErr chan error } // newInstance creates a new, unstarted server instance from a config. @@ -179,12 +382,29 @@ func newInstance(cfg Config) (*serverInstance, error) { return nil, fmt.Errorf("handler cannot be nil") } + // Set default timeouts + if cfg.ShutdownTimeout == 0 { + cfg.ShutdownTimeout = 30 * time.Second + } + if cfg.DrainTimeout == 0 { + cfg.DrainTimeout = 25 * time.Second + } + if cfg.ReadTimeout == 0 { + cfg.ReadTimeout = 15 * time.Second + } + if cfg.WriteTimeout == 0 { + cfg.WriteTimeout = 15 * time.Second + } + if cfg.IdleTimeout == 0 { + cfg.IdleTimeout = 60 * time.Second + } + addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) var handler http.Handler = cfg.Handler // Wrap with GZIP handler if enabled if cfg.GZIP { - gz, err := gzhttp.NewWrapper(gzhttp.BestSpeed) + gz, err := gzhttp.NewWrapper() if err != nil { return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err) } @@ -194,20 +414,33 @@ func newInstance(cfg Config) (*serverInstance, error) { // Wrap with the panic recovery middleware handler = middleware.PanicRecovery(handler) - // Here you could add other default middleware like request logging, metrics, etc. 
+ // Configure TLS if any TLS option is enabled + tlsConfig, certFile, keyFile, err := configureTLS(cfg) + if err != nil { + return nil, fmt.Errorf("failed to configure TLS: %w", err) + } - httpServer := &http.Server{ - Addr: addr, - Handler: handler, - ReadTimeout: 15 * time.Second, - WriteTimeout: 15 * time.Second, - IdleTimeout: 60 * time.Second, + // Create gracefulServer + gracefulSrv := &gracefulServer{ + server: &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + IdleTimeout: cfg.IdleTimeout, + TLSConfig: tlsConfig, + }, + shutdownTimeout: cfg.ShutdownTimeout, + drainTimeout: cfg.DrainTimeout, + shutdownComplete: make(chan struct{}), } return &serverInstance{ - cfg: cfg, - httpServer: httpServer, - stopCh: make(chan struct{}), + cfg: cfg, + gracefulServer: gracefulSrv, + certFile: certFile, + keyFile: keyFile, + serverErr: make(chan error, 1), }, nil } @@ -220,42 +453,69 @@ func (s *serverInstance) Start() error { return fmt.Errorf("server '%s' is already running", s.cfg.Name) } - hasSSL := s.cfg.SSLCert != "" && s.cfg.SSLKey != "" + // Determine if we're using TLS + useTLS := s.cfg.SSLCert != "" || s.cfg.SSLKey != "" || s.cfg.SelfSignedSSL || s.cfg.AutoTLS + + // Wrap handler with request tracking + s.gracefulServer.server.Handler = s.gracefulServer.trackRequestsMiddleware(s.gracefulServer.server.Handler) go func() { defer func() { s.mu.Lock() s.running = false s.mu.Unlock() - logger.Info("Server '%s' stopped.", s.cfg.Name, context.Background()) + logger.Info("Server '%s' stopped.", s.cfg.Name) }() var err error protocol := "HTTP" - if hasSSL { + if useTLS { protocol = "HTTPS" - // Configure TLS + HTTP/2 - s.httpServer.TLSConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, + logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr()) + + // For AutoTLS, we need to use a TLS listener + if s.cfg.AutoTLS { + // Create listener + ln, lnErr := net.Listen("tcp", 
s.gracefulServer.server.Addr) + if lnErr != nil { + err = fmt.Errorf("failed to create listener: %w", lnErr) + } else { + // Wrap with TLS + tlsListener := tls.NewListener(ln, s.gracefulServer.server.TLSConfig) + err = s.gracefulServer.server.Serve(tlsListener) + } + } else { + // Use certificate files (regular SSL or self-signed) + err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile) } - logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr(), context.Background()) err = s.httpServer.ListenAndServeTLS(s.cfg.SSLCert, s.cfg.SSLKey) } else { - logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr(), context.Background()) - err = s.httpServer.ListenAndServe() + logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr()) + err = s.gracefulServer.server.ListenAndServe() } - // If the server stopped for a reason other than a graceful shutdown, log the error. + // If the server stopped for a reason other than a graceful shutdown, log and report the error. if err != nil && err != http.ErrServerClosed { - logger.Error("Server '%s' failed: %v", s.cfg.Name, err, context.Background()) + logger.Error("Server '%s' failed: %v", s.cfg.Name, err) + select { + case s.serverErr <- err: + default: + } } }() s.running = true // A small delay to allow the goroutine to start and potentially fail on binding. - // A more robust solution might involve a channel signal. 
time.Sleep(50 * time.Millisecond) + // Check if the server failed to start + select { + case err := <-s.serverErr: + s.running = false + return err + default: + } + return nil } @@ -269,7 +529,7 @@ func (s *serverInstance) Stop(ctx context.Context) error { } logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name) - err := s.httpServer.Shutdown(ctx) + err := s.gracefulServer.shutdown(ctx) if err == nil { s.running = false } @@ -278,5 +538,35 @@ func (s *serverInstance) Stop(ctx context.Context) error { // Addr returns the network address the server is listening on. func (s *serverInstance) Addr() string { - return s.httpServer.Addr + return s.gracefulServer.server.Addr +} + +// Name returns the server instance name. +func (s *serverInstance) Name() string { + return s.cfg.Name +} + +// HealthCheckHandler returns a handler that responds to health checks. +func (s *serverInstance) HealthCheckHandler() http.HandlerFunc { + return s.gracefulServer.healthCheckHandler() +} + +// ReadinessHandler returns a handler for readiness checks. +func (s *serverInstance) ReadinessHandler() http.HandlerFunc { + return s.gracefulServer.readinessHandler() +} + +// InFlightRequests returns the current number of in-flight requests. +func (s *serverInstance) InFlightRequests() int64 { + return s.gracefulServer.inFlightRequestsCount() +} + +// IsShuttingDown returns true if the server is shutting down. +func (s *serverInstance) IsShuttingDown() bool { + return s.gracefulServer.isShutdown() +} + +// Wait blocks until shutdown is complete. 
+func (s *serverInstance) Wait() { + s.gracefulServer.wait() } diff --git a/pkg/server/manager_test.go b/pkg/server/manager_test.go index 1ff08f7..8ad123b 100644 --- a/pkg/server/manager_test.go +++ b/pkg/server/manager_test.go @@ -1,10 +1,12 @@ package server import ( + "context" "fmt" "io" "net" "net/http" + "sync" "testing" "time" @@ -123,3 +125,204 @@ func TestManagerErrorCases(t *testing.T) { _, err = sm.Add(config3) require.Error(t, err, "should not be able to add a server with a nil handler") } + +func TestGracefulShutdown(t *testing.T) { + logger.Init(true) + sm := NewManager() + + requestsHandled := 0 + var requestsMu sync.Mutex + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestsMu.Lock() + requestsHandled++ + requestsMu.Unlock() + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }) + + testPort := getFreePort(t) + instance, err := sm.Add(Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: handler, + DrainTimeout: 2 * time.Second, + }) + require.NoError(t, err) + + err = sm.StartAll() + require.NoError(t, err) + + // Give server time to start + time.Sleep(100 * time.Millisecond) + + // Send some concurrent requests + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + client := &http.Client{Timeout: 5 * time.Second} + url := fmt.Sprintf("http://localhost:%d", testPort) + resp, err := client.Get(url) + if err == nil { + resp.Body.Close() + } + }() + } + + // Wait a bit for requests to start + time.Sleep(50 * time.Millisecond) + + // Check in-flight requests + inFlight := instance.InFlightRequests() + assert.Greater(t, inFlight, int64(0), "Should have in-flight requests") + + // Stop the server + err = sm.StopAll() + require.NoError(t, err) + + // Wait for all requests to complete + wg.Wait() + + // Verify all requests were handled + requestsMu.Lock() + handled := requestsHandled + requestsMu.Unlock() + assert.GreaterOrEqual(t, 
handled, 1, "At least some requests should have been handled") + + // Verify no in-flight requests + assert.Equal(t, int64(0), instance.InFlightRequests(), "Should have no in-flight requests after shutdown") +} + +func TestHealthAndReadinessEndpoints(t *testing.T) { + logger.Init(true) + sm := NewManager() + + mux := http.NewServeMux() + testPort := getFreePort(t) + + instance, err := sm.Add(Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: mux, + }) + require.NoError(t, err) + + // Add health and readiness endpoints + mux.HandleFunc("/health", instance.HealthCheckHandler()) + mux.HandleFunc("/ready", instance.ReadinessHandler()) + + err = sm.StartAll() + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + client := &http.Client{Timeout: 2 * time.Second} + baseURL := fmt.Sprintf("http://localhost:%d", testPort) + + // Test health endpoint + resp, err := client.Get(baseURL + "/health") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + assert.Contains(t, string(body), "healthy") + + // Test readiness endpoint + resp, err = client.Get(baseURL + "/ready") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, _ = io.ReadAll(resp.Body) + resp.Body.Close() + assert.Contains(t, string(body), "ready") + assert.Contains(t, string(body), "in_flight_requests") + + // Stop the server + sm.StopAll() +} + +func TestRequestRejectionDuringShutdown(t *testing.T) { + logger.Init(true) + sm := NewManager() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }) + + testPort := getFreePort(t) + _, err := sm.Add(Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: handler, + DrainTimeout: 1 * time.Second, + }) + require.NoError(t, err) + + err = sm.StartAll() + require.NoError(t, err) + + time.Sleep(100 * 
time.Millisecond) + + // Start shutdown in background + go func() { + time.Sleep(50 * time.Millisecond) + sm.StopAll() + }() + + // Give shutdown time to start + time.Sleep(100 * time.Millisecond) + + // Try to make a request after shutdown started + client := &http.Client{Timeout: 2 * time.Second} + url := fmt.Sprintf("http://localhost:%d", testPort) + resp, err := client.Get(url) + + // The request should either fail (connection refused) or get 503 + if err == nil { + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Should get 503 during shutdown") + resp.Body.Close() + } +} + +func TestShutdownCallbacks(t *testing.T) { + logger.Init(true) + sm := NewManager() + + callbackExecuted := false + var callbackMu sync.Mutex + + sm.RegisterShutdownCallback(func(ctx context.Context) error { + callbackMu.Lock() + callbackExecuted = true + callbackMu.Unlock() + return nil + }) + + testPort := getFreePort(t) + _, err := sm.Add(Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: http.NewServeMux(), + }) + require.NoError(t, err) + + err = sm.StartAll() + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = sm.StopAll() + require.NoError(t, err) + + callbackMu.Lock() + executed := callbackExecuted + callbackMu.Unlock() + + assert.True(t, executed, "Shutdown callback should have been executed") +} diff --git a/pkg/server/shutdown.go b/pkg/server/shutdown.go deleted file mode 100644 index 9960405..0000000 --- a/pkg/server/shutdown.go +++ /dev/null @@ -1,296 +0,0 @@ -package server - -import ( - "context" - "fmt" - "net/http" - "os" - "os/signal" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/bitechdev/ResolveSpec/pkg/logger" -) - -// GracefulServer wraps http.Server with graceful shutdown capabilities -type GracefulServer struct { - server *http.Server - shutdownTimeout time.Duration - drainTimeout time.Duration - inFlightRequests atomic.Int64 - isShuttingDown atomic.Bool - shutdownOnce sync.Once - 
shutdownComplete chan struct{} -} - -// Config holds configuration for the graceful server -type Config struct { - // Addr is the server address (e.g., ":8080") - Addr string - - // Handler is the HTTP handler - Handler http.Handler - - // ShutdownTimeout is the maximum time to wait for graceful shutdown - // Default: 30 seconds - ShutdownTimeout time.Duration - - // DrainTimeout is the time to wait for in-flight requests to complete - // before forcing shutdown. Default: 25 seconds - DrainTimeout time.Duration - - // ReadTimeout is the maximum duration for reading the entire request - ReadTimeout time.Duration - - // WriteTimeout is the maximum duration before timing out writes of the response - WriteTimeout time.Duration - - // IdleTimeout is the maximum amount of time to wait for the next request - IdleTimeout time.Duration -} - -// NewGracefulServer creates a new graceful server -func NewGracefulServer(config Config) *GracefulServer { - if config.ShutdownTimeout == 0 { - config.ShutdownTimeout = 30 * time.Second - } - if config.DrainTimeout == 0 { - config.DrainTimeout = 25 * time.Second - } - if config.ReadTimeout == 0 { - config.ReadTimeout = 10 * time.Second - } - if config.WriteTimeout == 0 { - config.WriteTimeout = 10 * time.Second - } - if config.IdleTimeout == 0 { - config.IdleTimeout = 120 * time.Second - } - - gs := &GracefulServer{ - server: &http.Server{ - Addr: config.Addr, - Handler: config.Handler, - ReadTimeout: config.ReadTimeout, - WriteTimeout: config.WriteTimeout, - IdleTimeout: config.IdleTimeout, - }, - shutdownTimeout: config.ShutdownTimeout, - drainTimeout: config.DrainTimeout, - shutdownComplete: make(chan struct{}), - } - - return gs -} - -// TrackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown -func (gs *GracefulServer) TrackRequestsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if shutting down - if 
gs.isShuttingDown.Load() { - http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable) - return - } - - // Increment in-flight counter - gs.inFlightRequests.Add(1) - defer gs.inFlightRequests.Add(-1) - - // Serve the request - next.ServeHTTP(w, r) - }) -} - -// ListenAndServe starts the server and handles graceful shutdown -func (gs *GracefulServer) ListenAndServe() error { - // Wrap handler with request tracking - gs.server.Handler = gs.TrackRequestsMiddleware(gs.server.Handler) - - // Start server in goroutine - serverErr := make(chan error, 1) - go func() { - logger.Info("Starting server on %s", gs.server.Addr) - if err := gs.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - serverErr <- err - } - close(serverErr) - }() - - // Wait for interrupt signal - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) - - select { - case err := <-serverErr: - return err - case sig := <-sigChan: - logger.Info("Received signal: %v, initiating graceful shutdown", sig) - return gs.Shutdown(context.Background()) - } -} - -// Shutdown performs graceful shutdown with request draining -func (gs *GracefulServer) Shutdown(ctx context.Context) error { - var shutdownErr error - - gs.shutdownOnce.Do(func() { - logger.Info("Starting graceful shutdown...") - - // Mark as shutting down (new requests will be rejected) - gs.isShuttingDown.Store(true) - - // Create context with timeout - shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout) - defer cancel() - - // Wait for in-flight requests to complete (with drain timeout) - drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout) - defer drainCancel() - - shutdownErr = gs.drainRequests(drainCtx) - if shutdownErr != nil { - logger.Error("Error draining requests: %v", shutdownErr) - } - - // Shutdown the server - logger.Info("Shutting down HTTP server...") - if err := 
gs.server.Shutdown(shutdownCtx); err != nil { - logger.Error("Error shutting down server: %v", err) - if shutdownErr == nil { - shutdownErr = err - } - } - - logger.Info("Graceful shutdown complete") - close(gs.shutdownComplete) - }) - - return shutdownErr -} - -// drainRequests waits for in-flight requests to complete -func (gs *GracefulServer) drainRequests(ctx context.Context) error { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - startTime := time.Now() - - for { - inFlight := gs.inFlightRequests.Load() - - if inFlight == 0 { - logger.Info("All requests drained in %v", time.Since(startTime)) - return nil - } - - select { - case <-ctx.Done(): - logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight) - return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight) - case <-ticker.C: - logger.Debug("Waiting for %d in-flight requests to complete...", inFlight) - } - } -} - -// InFlightRequests returns the current number of in-flight requests -func (gs *GracefulServer) InFlightRequests() int64 { - return gs.inFlightRequests.Load() -} - -// IsShuttingDown returns true if the server is shutting down -func (gs *GracefulServer) IsShuttingDown() bool { - return gs.isShuttingDown.Load() -} - -// Wait blocks until shutdown is complete -func (gs *GracefulServer) Wait() { - <-gs.shutdownComplete -} - -// HealthCheckHandler returns a handler that responds to health checks -// Returns 200 OK when healthy, 503 Service Unavailable when shutting down -func (gs *GracefulServer) HealthCheckHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if gs.IsShuttingDown() { - http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(`{"status":"healthy"}`)) - if err != nil { - logger.Warn("Failed to write. 
%v", err) - } - } -} - -// ReadinessHandler returns a handler for readiness checks -// Includes in-flight request count -func (gs *GracefulServer) ReadinessHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if gs.IsShuttingDown() { - http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable) - return - } - - inFlight := gs.InFlightRequests() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight) - } -} - -// ShutdownCallback is a function called during shutdown -type ShutdownCallback func(context.Context) error - -// shutdownCallbacks stores registered shutdown callbacks -var ( - shutdownCallbacks []ShutdownCallback - shutdownCallbacksMu sync.Mutex -) - -// RegisterShutdownCallback registers a callback to be called during shutdown -// Useful for cleanup tasks like closing database connections, flushing metrics, etc. -func RegisterShutdownCallback(cb ShutdownCallback) { - shutdownCallbacksMu.Lock() - defer shutdownCallbacksMu.Unlock() - shutdownCallbacks = append(shutdownCallbacks, cb) -} - -// executeShutdownCallbacks runs all registered shutdown callbacks -func executeShutdownCallbacks(ctx context.Context) error { - shutdownCallbacksMu.Lock() - callbacks := make([]ShutdownCallback, len(shutdownCallbacks)) - copy(callbacks, shutdownCallbacks) - shutdownCallbacksMu.Unlock() - - var errors []error - for i, cb := range callbacks { - logger.Debug("Executing shutdown callback %d/%d", i+1, len(callbacks)) - if err := cb(ctx); err != nil { - logger.Error("Shutdown callback %d failed: %v", i+1, err) - errors = append(errors, err) - } - } - - if len(errors) > 0 { - return fmt.Errorf("shutdown callbacks failed: %v", errors) - } - - return nil -} - -// ShutdownWithCallbacks performs shutdown and executes all registered callbacks -func (gs *GracefulServer) ShutdownWithCallbacks(ctx context.Context) error { - // 
Execute callbacks first - if err := executeShutdownCallbacks(ctx); err != nil { - logger.Error("Error executing shutdown callbacks: %v", err) - } - - // Then shutdown the server - return gs.Shutdown(ctx) -} diff --git a/pkg/server/shutdown_test.go b/pkg/server/shutdown_test.go deleted file mode 100644 index 8eef439..0000000 --- a/pkg/server/shutdown_test.go +++ /dev/null @@ -1,231 +0,0 @@ -package server - -import ( - "context" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" -) - -func TestGracefulServerTrackRequests(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(100 * time.Millisecond) - w.WriteHeader(http.StatusOK) - }), - }) - - handler := srv.TrackRequestsMiddleware(srv.server.Handler) - - // Start some requests - var wg sync.WaitGroup - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - defer wg.Done() - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - }() - } - - // Wait a bit for requests to start - time.Sleep(10 * time.Millisecond) - - // Check in-flight count - inFlight := srv.InFlightRequests() - if inFlight == 0 { - t.Error("Should have in-flight requests") - } - - // Wait for all requests to complete - wg.Wait() - - // Check that counter is back to zero - inFlight = srv.InFlightRequests() - if inFlight != 0 { - t.Errorf("In-flight requests should be 0, got %d", inFlight) - } -} - -func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }), - }) - - handler := srv.TrackRequestsMiddleware(srv.server.Handler) - - // Mark as shutting down - srv.isShuttingDown.Store(true) - - // Try to make a request - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - 
// Should get 503 - if w.Code != http.StatusServiceUnavailable { - t.Errorf("Expected 503, got %d", w.Code) - } -} - -func TestHealthCheckHandler(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - }) - - handler := srv.HealthCheckHandler() - - // Healthy - t.Run("Healthy", func(t *testing.T) { - req := httptest.NewRequest("GET", "/health", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Expected 200, got %d", w.Code) - } - - if w.Body.String() != `{"status":"healthy"}` { - t.Errorf("Unexpected body: %s", w.Body.String()) - } - }) - - // Shutting down - t.Run("ShuttingDown", func(t *testing.T) { - srv.isShuttingDown.Store(true) - - req := httptest.NewRequest("GET", "/health", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusServiceUnavailable { - t.Errorf("Expected 503, got %d", w.Code) - } - }) -} - -func TestReadinessHandler(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - }) - - handler := srv.ReadinessHandler() - - // Ready with no in-flight requests - t.Run("Ready", func(t *testing.T) { - req := httptest.NewRequest("GET", "/ready", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Expected 200, got %d", w.Code) - } - - body := w.Body.String() - if body != `{"ready":true,"in_flight_requests":0}` { - t.Errorf("Unexpected body: %s", body) - } - }) - - // Not ready during shutdown - t.Run("NotReady", func(t *testing.T) { - srv.isShuttingDown.Store(true) - - req := httptest.NewRequest("GET", "/ready", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusServiceUnavailable { - t.Errorf("Expected 503, got %d", w.Code) - } - }) -} - -func TestShutdownCallbacks(t *testing.T) { - 
callbackExecuted := false - - RegisterShutdownCallback(func(ctx context.Context) error { - callbackExecuted = true - return nil - }) - - ctx := context.Background() - err := executeShutdownCallbacks(ctx) - - if err != nil { - t.Errorf("executeShutdownCallbacks() error = %v", err) - } - - if !callbackExecuted { - t.Error("Shutdown callback was not executed") - } - - // Reset for other tests - shutdownCallbacks = nil -} - -func TestDrainRequests(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - DrainTimeout: 1 * time.Second, - }) - - // Simulate in-flight requests - srv.inFlightRequests.Add(3) - - // Start draining in background - go func() { - time.Sleep(100 * time.Millisecond) - // Simulate requests completing - srv.inFlightRequests.Add(-3) - }() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := srv.drainRequests(ctx) - if err != nil { - t.Errorf("drainRequests() error = %v", err) - } - - if srv.InFlightRequests() != 0 { - t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests()) - } -} - -func TestDrainRequestsTimeout(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - DrainTimeout: 100 * time.Millisecond, - }) - - // Simulate in-flight requests that don't complete - srv.inFlightRequests.Add(5) - - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - - err := srv.drainRequests(ctx) - if err == nil { - t.Error("drainRequests() should timeout with error") - } - - // Cleanup - srv.inFlightRequests.Add(-5) -} - -func TestGetClientIP(t *testing.T) { - // This test is in ratelimit_test.go since getClientIP is used by rate limiter - // Including here for completeness of server tests -} diff --git a/pkg/server/tls.go b/pkg/server/tls.go new file mode 100644 index 
0000000..5961291 --- /dev/null +++ b/pkg/server/tls.go @@ -0,0 +1,190 @@ +package server + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "path/filepath" + "time" + + "golang.org/x/crypto/acme/autocert" +) + +// generateSelfSignedCert generates a self-signed certificate for the given host. +// Returns the certificate and private key in PEM format. +func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) { + // Generate private key + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + // Create certificate template + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"ResolveSpec Self-Signed"}, + CommonName: host, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // Add host as DNS name or IP address + if ip := net.ParseIP(host); ip != nil { + template.IPAddresses = []net.IP{ip} + } else { + template.DNSNames = []string{host} + } + + // Create certificate + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + // Encode certificate to PEM + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + // Encode private key to PEM + privBytes, err := 
x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) + + return certPEM, keyPEM, nil +} + +// saveCertToTempFiles saves certificate and key PEM data to temporary files. +// Returns the file paths for the certificate and key. +func saveCertToTempFiles(certPEM, keyPEM []byte) (certFile, keyFile string, err error) { + // Create temporary directory + tmpDir, err := os.MkdirTemp("", "resolvespec-certs-*") + if err != nil { + return "", "", fmt.Errorf("failed to create temp directory: %w", err) + } + + certFile = filepath.Join(tmpDir, "cert.pem") + keyFile = filepath.Join(tmpDir, "key.pem") + + // Write certificate + if err := os.WriteFile(certFile, certPEM, 0600); err != nil { + os.RemoveAll(tmpDir) + return "", "", fmt.Errorf("failed to write certificate: %w", err) + } + + // Write key + if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil { + os.RemoveAll(tmpDir) + return "", "", fmt.Errorf("failed to write private key: %w", err) + } + + return certFile, keyFile, nil +} + +// setupAutoTLS configures automatic TLS certificate management using Let's Encrypt. +// Returns a TLS config that can be used with http.Server. 
+func setupAutoTLS(domains []string, email, cacheDir string) (*tls.Config, error) { + if len(domains) == 0 { + return nil, fmt.Errorf("at least one domain must be specified for AutoTLS") + } + + // Set default cache directory + if cacheDir == "" { + cacheDir = "./certs-cache" + } + + // Create cache directory if it doesn't exist + if err := os.MkdirAll(cacheDir, 0700); err != nil { + return nil, fmt.Errorf("failed to create certificate cache directory: %w", err) + } + + // Create autocert manager + m := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(cacheDir), + HostPolicy: autocert.HostWhitelist(domains...), + Email: email, + } + + // Create TLS config + tlsConfig := m.TLSConfig() + tlsConfig.MinVersion = tls.VersionTLS12 + + return tlsConfig, nil +} + +// configureTLS configures TLS for the server based on the provided configuration. +// Returns the TLS config and certificate/key file paths (if applicable). +func configureTLS(cfg Config) (*tls.Config, string, string, error) { + // Option 1: Certificate files provided + if cfg.SSLCert != "" && cfg.SSLKey != "" { + // Validate that files exist + if _, err := os.Stat(cfg.SSLCert); os.IsNotExist(err) { + return nil, "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert) + } + if _, err := os.Stat(cfg.SSLKey); os.IsNotExist(err) { + return nil, "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey) + } + + // Return basic TLS config - cert/key will be loaded by ListenAndServeTLS + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + return tlsConfig, cfg.SSLCert, cfg.SSLKey, nil + } + + // Option 2: Auto TLS (Let's Encrypt) + if cfg.AutoTLS { + tlsConfig, err := setupAutoTLS(cfg.AutoTLSDomains, cfg.AutoTLSEmail, cfg.AutoTLSCacheDir) + if err != nil { + return nil, "", "", fmt.Errorf("failed to setup AutoTLS: %w", err) + } + return tlsConfig, "", "", nil + } + + // Option 3: Self-signed certificate + if cfg.SelfSignedSSL { + host := cfg.Host + if host == "" 
|| host == "0.0.0.0" { + host = "localhost" + } + + certPEM, keyPEM, err := generateSelfSignedCert(host) + if err != nil { + return nil, "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err) + } + + certFile, keyFile, err := saveCertToTempFiles(certPEM, keyPEM) + if err != nil { + return nil, "", "", fmt.Errorf("failed to save self-signed certificate: %w", err) + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + return tlsConfig, certFile, keyFile, nil + } + + return nil, "", "", nil +}