Better server manager

This commit is contained in:
Hein
2025-12-29 17:19:16 +02:00
parent 8f83e8fdc1
commit d4a6f9c4c2
8 changed files with 1475 additions and 857 deletions

View File

@@ -1,233 +1,314 @@
# Server Package
Graceful HTTP server with request draining and shutdown coordination.
Production-ready HTTP server manager with graceful shutdown, request draining, and comprehensive TLS/HTTPS support.
## Features
**Multiple Server Management** - Run multiple HTTP/HTTPS servers concurrently
**Graceful Shutdown** - Handles SIGINT/SIGTERM with request draining
**Automatic Request Rejection** - New requests get 503 during shutdown
**Health & Readiness Endpoints** - Kubernetes-ready health checks
**Shutdown Callbacks** - Register cleanup functions (DB, cache, metrics)
**Comprehensive TLS Support**:
- Certificate files (production)
- Self-signed certificates (development/testing)
- Let's Encrypt / AutoTLS (automatic certificate management)
**GZIP Compression** - Optional response compression
**Panic Recovery** - Automatic panic recovery middleware
**Configurable Timeouts** - Read, write, idle, drain, and shutdown timeouts
## Quick Start
### Single Server
```go
import "github.com/bitechdev/ResolveSpec/pkg/server"
// Create server
srv := server.NewGracefulServer(server.Config{
Addr: ":8080",
Handler: router,
// Create server manager
mgr := server.NewManager()
// Add server
_, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: myRouter,
GZIP: true,
})
// Start server (blocks until shutdown signal)
if err := srv.ListenAndServe(); err != nil {
// Start and wait for shutdown signal
if err := mgr.ServeWithGracefulShutdown(); err != nil {
log.Fatal(err)
}
```
## Features
### Multiple Servers
✅ Graceful shutdown on SIGINT/SIGTERM
✅ Request draining (waits for in-flight requests)
✅ Automatic request rejection during shutdown
✅ Health and readiness endpoints
✅ Shutdown callbacks for cleanup
✅ Configurable timeouts
```go
mgr := server.NewManager()
// Public API
mgr.Add(server.Config{
Name: "public-api",
Port: 8080,
Handler: publicRouter,
})
// Admin API
mgr.Add(server.Config{
Name: "admin-api",
Port: 8081,
Handler: adminRouter,
})
// Start all and wait
mgr.ServeWithGracefulShutdown()
```
## HTTPS/TLS Configuration
### Option 1: Certificate Files (Production)
```go
mgr.Add(server.Config{
Name: "https-server",
Host: "0.0.0.0",
Port: 443,
Handler: handler,
SSLCert: "/etc/ssl/certs/server.crt",
SSLKey: "/etc/ssl/private/server.key",
})
```
### Option 2: Self-Signed Certificate (Development)
```go
mgr.Add(server.Config{
Name: "dev-server",
Host: "localhost",
Port: 8443,
Handler: handler,
SelfSignedSSL: true, // Auto-generates certificate
})
```
### Option 3: Let's Encrypt / AutoTLS (Production)
```go
mgr.Add(server.Config{
Name: "prod-server",
Host: "0.0.0.0",
Port: 443,
Handler: handler,
AutoTLS: true,
AutoTLSDomains: []string{"example.com", "www.example.com"},
AutoTLSEmail: "admin@example.com",
AutoTLSCacheDir: "./certs-cache", // Certificate cache directory
})
```
## Configuration
```go
config := server.Config{
// Server address
Addr: ":8080",
server.Config{
// Basic configuration
Name: "my-server", // Server name (required)
Host: "0.0.0.0", // Bind address
Port: 8080, // Port (required)
Handler: myRouter, // HTTP handler (required)
Description: "My API server", // Optional description
// HTTP handler
Handler: myRouter,
// Features
GZIP: true, // Enable GZIP compression
// Maximum time for graceful shutdown (default: 30s)
ShutdownTimeout: 30 * time.Second,
// TLS/HTTPS (choose one option)
SSLCert: "/path/to/cert.pem", // Certificate file
SSLKey: "/path/to/key.pem", // Key file
SelfSignedSSL: false, // Auto-generate self-signed cert
AutoTLS: false, // Let's Encrypt
AutoTLSDomains: []string{}, // Domains for AutoTLS
AutoTLSEmail: "", // Email for Let's Encrypt
AutoTLSCacheDir: "./certs-cache", // Cert cache directory
// Time to wait for in-flight requests (default: 25s)
DrainTimeout: 25 * time.Second,
// Request read timeout (default: 10s)
ReadTimeout: 10 * time.Second,
// Response write timeout (default: 10s)
WriteTimeout: 10 * time.Second,
// Idle connection timeout (default: 120s)
IdleTimeout: 120 * time.Second,
// Timeouts
ShutdownTimeout: 30 * time.Second, // Max shutdown time
DrainTimeout: 25 * time.Second, // Request drain timeout
ReadTimeout: 15 * time.Second, // Request read timeout
WriteTimeout: 15 * time.Second, // Response write timeout
IdleTimeout: 60 * time.Second, // Idle connection timeout
}
srv := server.NewGracefulServer(config)
```
## Shutdown Behavior
## Graceful Shutdown
**Signal received (SIGINT/SIGTERM):**
### Automatic (Recommended)
1. **Mark as shutting down** - New requests get 503
2. **Drain requests** - Wait up to `DrainTimeout` for in-flight requests
3. **Shutdown server** - Close listeners and connections
4. **Execute callbacks** - Run registered cleanup functions
```go
mgr := server.NewManager()
// Add servers...
// This blocks until SIGINT/SIGTERM
mgr.ServeWithGracefulShutdown()
```
Time Event
─────────────────────────────────────────
0s Signal received: SIGTERM
├─ Mark as shutting down
├─ Reject new requests (503)
└─ Start draining...
1s In-flight: 50 requests
2s In-flight: 32 requests
3s In-flight: 12 requests
4s In-flight: 3 requests
5s In-flight: 0 requests ✓
└─ All requests drained
### Manual Control
5s Execute shutdown callbacks
6s Shutdown complete
```go
mgr := server.NewManager()
// Add and start servers
mgr.StartAll()
// Later... stop gracefully
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := mgr.StopAllWithContext(ctx); err != nil {
log.Printf("Shutdown error: %v", err)
}
```
### Shutdown Callbacks
Register cleanup functions to run during shutdown:
```go
// Close database
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Closing database...")
return db.Close()
})
// Flush metrics
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Flushing metrics...")
return metrics.Flush(ctx)
})
// Close cache
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Closing cache...")
return cache.Close()
})
```
## Health Checks
### Health Endpoint
Returns 200 when healthy, 503 when shutting down:
### Adding Health Endpoints
```go
router.HandleFunc("/health", srv.HealthCheckHandler())
instance, _ := mgr.Add(server.Config{
Name: "api-server",
Port: 8080,
Handler: router,
})
// Add health endpoints to your router
router.HandleFunc("/health", instance.HealthCheckHandler())
router.HandleFunc("/ready", instance.ReadinessHandler())
```
**Response (healthy):**
### Health Endpoint
Returns server health status:
**Healthy (200 OK):**
```json
{"status":"healthy"}
```
**Response (shutting down):**
**Shutting Down (503 Service Unavailable):**
```json
{"status":"shutting_down"}
```
### Readiness Endpoint
Includes in-flight request count:
Returns readiness with in-flight request count:
```go
router.HandleFunc("/ready", srv.ReadinessHandler())
```
**Response:**
**Ready (200 OK):**
```json
{"ready":true,"in_flight_requests":12}
```
**During shutdown:**
**Not Ready (503 Service Unavailable):**
```json
{"ready":false,"reason":"shutting_down"}
```
## Shutdown Callbacks
## Shutdown Behavior
Register cleanup functions to run during shutdown:
When a shutdown signal (SIGINT/SIGTERM) is received:
```go
// Close database
server.RegisterShutdownCallback(func(ctx context.Context) error {
logger.Info("Closing database connection...")
return db.Close()
})
1. **Mark as shutting down** → New requests get 503
2. **Execute callbacks** → Run cleanup functions
3. **Drain requests** → Wait up to `DrainTimeout` for in-flight requests
4. **Shutdown servers** → Close listeners and connections
// Flush metrics
server.RegisterShutdownCallback(func(ctx context.Context) error {
logger.Info("Flushing metrics...")
return metricsProvider.Flush(ctx)
})
```
Time Event
─────────────────────────────────────────
0s Signal received: SIGTERM
├─ Mark servers as shutting down
├─ Reject new requests (503)
└─ Execute shutdown callbacks
// Close cache
server.RegisterShutdownCallback(func(ctx context.Context) error {
logger.Info("Closing cache...")
return cache.Close()
})
1s Callbacks complete
└─ Start draining requests...
2s In-flight: 50 requests
3s In-flight: 32 requests
4s In-flight: 12 requests
5s In-flight: 3 requests
6s In-flight: 0 requests ✓
└─ All requests drained
6s Shutdown servers
7s All servers stopped ✓
```
## Complete Example
## Server Management
### Get Server Instance
```go
package main
import (
"context"
"log"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/bitechdev/ResolveSpec/pkg/server"
"github.com/gorilla/mux"
)
func main() {
// Initialize metrics
metricsProvider := metrics.NewPrometheusProvider()
metrics.SetProvider(metricsProvider)
// Create router
router := mux.NewRouter()
// Apply middleware
rateLimiter := middleware.NewRateLimiter(100, 20)
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size10MB)
sanitizer := middleware.DefaultSanitizer()
router.Use(rateLimiter.Middleware)
router.Use(sizeLimiter.Middleware)
router.Use(sanitizer.Middleware)
router.Use(metricsProvider.Middleware)
// API routes
router.HandleFunc("/api/data", dataHandler)
// Create graceful server
srv := server.NewGracefulServer(server.Config{
Addr: ":8080",
Handler: router,
ShutdownTimeout: 30 * time.Second,
DrainTimeout: 25 * time.Second,
})
// Health checks
router.HandleFunc("/health", srv.HealthCheckHandler())
router.HandleFunc("/ready", srv.ReadinessHandler())
// Metrics endpoint
router.Handle("/metrics", metricsProvider.Handler())
// Register shutdown callbacks
server.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Cleanup: Flushing metrics...")
return nil
})
server.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Cleanup: Closing database...")
// return db.Close()
return nil
})
// Start server (blocks until shutdown)
log.Printf("Starting server on :8080")
if err := srv.ListenAndServe(); err != nil {
log.Fatal(err)
}
// Wait for shutdown to complete
srv.Wait()
log.Println("Server stopped")
instance, err := mgr.Get("api-server")
if err != nil {
log.Fatal(err)
}
func dataHandler(w http.ResponseWriter, r *http.Request) {
// Your handler logic
time.Sleep(100 * time.Millisecond) // Simulate work
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message":"success"}`))
// Check status
fmt.Printf("Address: %s\n", instance.Addr())
fmt.Printf("Name: %s\n", instance.Name())
fmt.Printf("In-flight: %d\n", instance.InFlightRequests())
fmt.Printf("Shutting down: %v\n", instance.IsShuttingDown())
```
### List All Servers
```go
instances := mgr.List()
for _, instance := range instances {
fmt.Printf("Server: %s at %s\n", instance.Name(), instance.Addr())
}
```
### Remove Server
```go
// Stop and remove a server
if err := mgr.Remove("api-server"); err != nil {
log.Printf("Error removing server: %v", err)
}
```
### Restart All Servers
```go
// Gracefully restart all servers
if err := mgr.RestartAll(); err != nil {
log.Printf("Error restarting: %v", err)
}
```
@@ -250,23 +331,21 @@ spec:
ports:
- containerPort: 8080
# Liveness probe - is app running?
# Liveness probe
livenessProbe:
httpGet:
path: /health
port: 8080
initialDelaySeconds: 10
periodSeconds: 10
timeoutSeconds: 5
# Readiness probe - can app handle traffic?
# Readiness probe
readinessProbe:
httpGet:
path: /ready
port: 8080
initialDelaySeconds: 5
periodSeconds: 5
timeoutSeconds: 3
# Graceful shutdown
lifecycle:
@@ -274,26 +353,12 @@ spec:
exec:
command: ["/bin/sh", "-c", "sleep 5"]
# Environment
env:
- name: SHUTDOWN_TIMEOUT
value: "30"
```
### Service
```yaml
apiVersion: v1
kind: Service
metadata:
name: myapp
spec:
selector:
app: myapp
ports:
- port: 80
targetPort: 8080
type: LoadBalancer
# Allow time for graceful shutdown
terminationGracePeriodSeconds: 35
```
## Docker Compose
@@ -312,8 +377,70 @@ services:
interval: 10s
timeout: 5s
retries: 3
start_period: 10s
stop_grace_period: 35s # Slightly longer than shutdown timeout
stop_grace_period: 35s
```
## Complete Example
```go
package main
import (
"context"
"log"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/server"
)
func main() {
// Create server manager
mgr := server.NewManager()
// Register shutdown callbacks
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Cleanup: Closing database...")
// return db.Close()
return nil
})
// Create router
router := http.NewServeMux()
router.HandleFunc("/api/data", dataHandler)
// Add server
instance, err := mgr.Add(server.Config{
Name: "api-server",
Host: "0.0.0.0",
Port: 8080,
Handler: router,
GZIP: true,
ShutdownTimeout: 30 * time.Second,
DrainTimeout: 25 * time.Second,
})
if err != nil {
log.Fatal(err)
}
// Add health endpoints
router.HandleFunc("/health", instance.HealthCheckHandler())
router.HandleFunc("/ready", instance.ReadinessHandler())
// Start and wait for shutdown
log.Println("Starting server on :8080")
if err := mgr.ServeWithGracefulShutdown(); err != nil {
log.Printf("Server stopped: %v", err)
}
log.Println("Server shutdown complete")
}
func dataHandler(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond) // Simulate work
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message":"success"}`))
}
```
## Testing Graceful Shutdown
@@ -330,7 +457,7 @@ SERVER_PID=$!
# Wait for server to start
sleep 2
# Send some requests
# Send requests
for i in {1..10}; do
curl http://localhost:8080/api/data &
done
@@ -341,7 +468,7 @@ sleep 1
# Send shutdown signal
kill -TERM $SERVER_PID
# Try to send more requests (should get 503)
# Try more requests (should get 503)
curl -v http://localhost:8080/api/data
# Wait for server to stop
@@ -349,101 +476,13 @@ wait $SERVER_PID
echo "Server stopped gracefully"
```
### Expected Output
```
Starting server on :8080
Received signal: terminated, initiating graceful shutdown
Starting graceful shutdown...
Waiting for 8 in-flight requests to complete...
Waiting for 4 in-flight requests to complete...
Waiting for 1 in-flight requests to complete...
All requests drained in 2.3s
Cleanup: Flushing metrics...
Cleanup: Closing database...
Shutting down HTTP server...
Graceful shutdown complete
Server stopped
```
## Monitoring In-Flight Requests
```go
// Get current in-flight count
count := srv.InFlightRequests()
fmt.Printf("In-flight requests: %d\n", count)
// Check if shutting down
if srv.IsShuttingDown() {
fmt.Println("Server is shutting down")
}
```
## Advanced Usage
### Custom Shutdown Logic
```go
// Implement custom shutdown
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
log.Println("Shutdown signal received")
// Custom pre-shutdown logic
log.Println("Running custom cleanup...")
// Shutdown with callbacks
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.ShutdownWithCallbacks(ctx); err != nil {
log.Printf("Shutdown error: %v", err)
}
}()
// Start server
srv.server.ListenAndServe()
```
### Multiple Servers
```go
// HTTP server
httpSrv := server.NewGracefulServer(server.Config{
Addr: ":8080",
Handler: httpRouter,
})
// HTTPS server
httpsSrv := server.NewGracefulServer(server.Config{
Addr: ":8443",
Handler: httpsRouter,
})
// Start both
go httpSrv.ListenAndServe()
go httpsSrv.ListenAndServe()
// Shutdown both on signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
<-sigChan
ctx := context.Background()
httpSrv.Shutdown(ctx)
httpsSrv.Shutdown(ctx)
```
## Best Practices
1. **Set appropriate timeouts**
- `DrainTimeout` < `ShutdownTimeout`
- `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds`
2. **Register cleanup callbacks** for:
2. **Use shutdown callbacks** for:
- Database connections
- Message queues
- Metrics flushing
@@ -458,7 +497,12 @@ httpsSrv.Shutdown(ctx)
- Set `preStop` hook in Kubernetes (5-10s delay)
- Allows load balancer to deregister before shutdown
5. **Monitoring**
5. **HTTPS in production**
- Use AutoTLS for public-facing services
- Use certificate files for enterprise PKI
- Use self-signed only for development/testing
6. **Monitoring**
- Track in-flight requests in metrics
- Alert on slow drains
- Monitor shutdown duration
@@ -470,24 +514,63 @@ httpsSrv.Shutdown(ctx)
```go
// Increase drain timeout
config.DrainTimeout = 60 * time.Second
config.ShutdownTimeout = 65 * time.Second
```
### Requests Still Timing Out
### Requests Timing Out
```go
// Increase write timeout
config.WriteTimeout = 30 * time.Second
```
### Force Shutdown Not Working
The server will force shutdown after `ShutdownTimeout` even if requests are still in-flight. Adjust timeouts as needed.
### Debugging Shutdown
### Certificate Issues
```go
// Verify certificate files exist and are readable
if _, err := os.Stat(config.SSLCert); err != nil {
log.Fatalf("Certificate not found: %v", err)
}
// For AutoTLS, ensure:
// - Port 443 is accessible
// - Domains resolve to server IP
// - Cache directory is writable
```
### Debug Logging
```go
// Enable debug logging
import "github.com/bitechdev/ResolveSpec/pkg/logger"
// Enable debug logging
logger.SetLevel("debug")
```
## API Reference
### Manager Methods
- `NewManager()` - Create new server manager
- `Add(cfg Config)` - Register server instance
- `Get(name string)` - Get server by name
- `Remove(name string)` - Stop and remove server
- `StartAll()` - Start all registered servers
- `StopAll()` - Stop all servers gracefully
- `StopAllWithContext(ctx)` - Stop with timeout
- `RestartAll()` - Restart all servers
- `List()` - Get all server instances
- `ServeWithGracefulShutdown()` - Start and block until shutdown
- `RegisterShutdownCallback(cb)` - Register cleanup function
### Instance Methods
- `Start()` - Start the server
- `Stop(ctx)` - Stop gracefully
- `Addr()` - Get server address
- `Name()` - Get server name
- `HealthCheckHandler()` - Get health handler
- `ReadinessHandler()` - Get readiness handler
- `InFlightRequests()` - Get in-flight count
- `IsShuttingDown()` - Check shutdown status
- `Wait()` - Block until shutdown complete

294
pkg/server/example_test.go Normal file
View File

@@ -0,0 +1,294 @@
package server_test
import (
"context"
"fmt"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/server"
)
// ExampleManager_basic demonstrates basic server manager usage
func ExampleManager_basic() {
// Create a server manager
mgr := server.NewManager()
// Define a simple handler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "Hello from server!")
})
// Add an HTTP server
_, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: handler,
GZIP: true, // Enable GZIP compression
})
if err != nil {
panic(err)
}
// Start all servers
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Server is now running...
// When done, stop gracefully
if err := mgr.StopAll(); err != nil {
panic(err)
}
}
// ExampleManager_https demonstrates HTTPS configurations
func ExampleManager_https() {
mgr := server.NewManager()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Secure connection!")
})
// Option 1: Use certificate files
_, err := mgr.Add(server.Config{
Name: "https-server-files",
Host: "localhost",
Port: 8443,
Handler: handler,
SSLCert: "/path/to/cert.pem",
SSLKey: "/path/to/key.pem",
})
if err != nil {
panic(err)
}
// Option 2: Self-signed certificate (for development)
_, err = mgr.Add(server.Config{
Name: "https-server-self-signed",
Host: "localhost",
Port: 8444,
Handler: handler,
SelfSignedSSL: true,
})
if err != nil {
panic(err)
}
// Option 3: Let's Encrypt / AutoTLS (for production)
_, err = mgr.Add(server.Config{
Name: "https-server-letsencrypt",
Host: "0.0.0.0",
Port: 443,
Handler: handler,
AutoTLS: true,
AutoTLSDomains: []string{"example.com", "www.example.com"},
AutoTLSEmail: "admin@example.com",
AutoTLSCacheDir: "./certs-cache",
})
if err != nil {
panic(err)
}
// Start all servers
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Cleanup
mgr.StopAll()
}
// ExampleManager_gracefulShutdown demonstrates graceful shutdown with callbacks
func ExampleManager_gracefulShutdown() {
mgr := server.NewManager()
// Register shutdown callbacks for cleanup tasks
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
fmt.Println("Closing database connections...")
// Close your database here
return nil
})
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
fmt.Println("Flushing metrics...")
// Flush metrics here
return nil
})
// Add server with custom timeouts
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate some work
time.Sleep(100 * time.Millisecond)
fmt.Fprintln(w, "Done!")
})
_, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: handler,
ShutdownTimeout: 30 * time.Second, // Max time for shutdown
DrainTimeout: 25 * time.Second, // Time to wait for in-flight requests
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
})
if err != nil {
panic(err)
}
// Start servers and block until shutdown signal (SIGINT/SIGTERM)
// This will automatically handle graceful shutdown with callbacks
if err := mgr.ServeWithGracefulShutdown(); err != nil {
fmt.Printf("Shutdown completed: %v\n", err)
}
}
// ExampleManager_healthChecks demonstrates health and readiness endpoints
func ExampleManager_healthChecks() {
mgr := server.NewManager()
// Create a router with health endpoints
mux := http.NewServeMux()
mux.HandleFunc("/api/data", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Data endpoint")
})
// Add server
instance, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: mux,
})
if err != nil {
panic(err)
}
// Add health and readiness endpoints
mux.HandleFunc("/health", instance.HealthCheckHandler())
mux.HandleFunc("/ready", instance.ReadinessHandler())
// Start the server
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Health check returns:
// - 200 OK with {"status":"healthy"} when healthy
// - 503 Service Unavailable with {"status":"shutting_down"} when shutting down
// Readiness check returns:
// - 200 OK with {"ready":true,"in_flight_requests":N} when ready
// - 503 Service Unavailable with {"ready":false,"reason":"shutting_down"} when shutting down
// Cleanup
mgr.StopAll()
}
// ExampleManager_multipleServers demonstrates running multiple servers
func ExampleManager_multipleServers() {
mgr := server.NewManager()
// Public API server
publicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Public API")
})
_, err := mgr.Add(server.Config{
Name: "public-api",
Host: "0.0.0.0",
Port: 8080,
Handler: publicHandler,
GZIP: true,
})
if err != nil {
panic(err)
}
// Admin API server (different port)
adminHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Admin API")
})
_, err = mgr.Add(server.Config{
Name: "admin-api",
Host: "localhost",
Port: 8081,
Handler: adminHandler,
})
if err != nil {
panic(err)
}
// Metrics server (internal only)
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Metrics data")
})
_, err = mgr.Add(server.Config{
Name: "metrics",
Host: "127.0.0.1",
Port: 9090,
Handler: metricsHandler,
})
if err != nil {
panic(err)
}
// Start all servers at once
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Get specific server instance
publicInstance, err := mgr.Get("public-api")
if err != nil {
panic(err)
}
fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
// List all servers
instances := mgr.List()
fmt.Printf("Running %d servers\n", len(instances))
// Stop all servers gracefully (in parallel)
if err := mgr.StopAll(); err != nil {
panic(err)
}
}
// ExampleManager_monitoring demonstrates monitoring server state
func ExampleManager_monitoring() {
mgr := server.NewManager()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(50 * time.Millisecond) // Simulate work
fmt.Fprintln(w, "Done")
})
instance, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: handler,
})
if err != nil {
panic(err)
}
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Check server status
fmt.Printf("Server address: %s\n", instance.Addr())
fmt.Printf("Server name: %s\n", instance.Name())
fmt.Printf("Is shutting down: %v\n", instance.IsShuttingDown())
fmt.Printf("In-flight requests: %d\n", instance.InFlightRequests())
// Cleanup
mgr.StopAll()
// Wait for complete shutdown
instance.Wait()
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"net/http"
"time"
)
// Config holds the configuration for a single web server instance.
@@ -11,11 +12,52 @@ type Config struct {
Host string
Port int
Description string
SSLCert string
SSLKey string
GZIP bool
// Handler is the http.Handler (e.g., a router) to be served.
Handler http.Handler
// GZIP compression support
GZIP bool
// TLS/HTTPS configuration options (mutually exclusive)
// Option 1: Provide certificate and key files directly
SSLCert string
SSLKey string
// Option 2: Use self-signed certificate (for development/testing)
// Generates a self-signed certificate automatically if no SSLCert/SSLKey provided
SelfSignedSSL bool
// Option 3: Use Let's Encrypt / Certbot for automatic TLS
// AutoTLS enables automatic certificate management via Let's Encrypt
AutoTLS bool
// AutoTLSDomains specifies the domains for Let's Encrypt certificates
AutoTLSDomains []string
// AutoTLSCacheDir specifies where to cache certificates (default: "./certs-cache")
AutoTLSCacheDir string
// AutoTLSEmail is the email for Let's Encrypt registration (optional but recommended)
AutoTLSEmail string
// Graceful shutdown configuration
// ShutdownTimeout is the maximum time to wait for graceful shutdown
// Default: 30 seconds
ShutdownTimeout time.Duration
// DrainTimeout is the time to wait for in-flight requests to complete
// before forcing shutdown. Default: 25 seconds
DrainTimeout time.Duration
// ReadTimeout is the maximum duration for reading the entire request
// Default: 15 seconds
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out writes of the response
// Default: 15 seconds
WriteTimeout time.Duration
// IdleTimeout is the maximum amount of time to wait for the next request
// Default: 60 seconds
IdleTimeout time.Duration
}
// Instance defines the interface for a single server instance.
@@ -24,11 +66,33 @@ type Instance interface {
// Start begins serving requests. This method should be non-blocking and
// run the server in a separate goroutine.
Start() error
// Stop gracefully shuts down the server without interrupting any active connections.
// It accepts a context to allow for a timeout.
Stop(ctx context.Context) error
// Addr returns the network address the server is listening on.
Addr() string
// Name returns the server instance name.
Name() string
// HealthCheckHandler returns a handler that responds to health checks.
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down.
HealthCheckHandler() http.HandlerFunc
// ReadinessHandler returns a handler for readiness checks.
// Includes in-flight request count.
ReadinessHandler() http.HandlerFunc
// InFlightRequests returns the current number of in-flight requests.
InFlightRequests() int64
// IsShuttingDown returns true if the server is shutting down.
IsShuttingDown() bool
// Wait blocks until shutdown is complete.
Wait()
}
// Manager defines the interface for a server manager.
@@ -37,16 +101,37 @@ type Manager interface {
// Add registers a new server instance based on the provided configuration.
// The server is not started until StartAll or Start is called on the instance.
Add(cfg Config) (Instance, error)
// Get returns a server instance by its name.
Get(name string) (Instance, error)
// Remove stops and removes a server instance by its name.
Remove(name string) error
// StartAll starts all registered server instances that are not already running.
StartAll() error
// StopAll gracefully shuts down all running server instances.
// Executes shutdown callbacks and drains in-flight requests.
StopAll() error
// StopAllWithContext gracefully shuts down all running server instances with a context.
StopAllWithContext(ctx context.Context) error
// RestartAll gracefully restarts all running server instances.
RestartAll() error
// List returns all registered server instances.
List() []Instance
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
// It handles SIGINT and SIGTERM signals and performs graceful shutdown with callbacks.
ServeWithGracefulShutdown() error
// RegisterShutdownCallback registers a callback to be called during shutdown.
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
RegisterShutdownCallback(cb ShutdownCallback)
}
// ShutdownCallback is a function called during graceful shutdown.
type ShutdownCallback func(context.Context) error

View File

@@ -4,26 +4,173 @@ import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/klauspost/compress/gzhttp"
"golang.org/x/net/http2"
)
// serverManager manages a collection of server instances.
// gracefulServer wraps http.Server with graceful shutdown capabilities (internal type)
type gracefulServer struct {
server *http.Server
shutdownTimeout time.Duration
drainTimeout time.Duration
inFlightRequests atomic.Int64
isShuttingDown atomic.Bool
shutdownOnce sync.Once
shutdownComplete chan struct{}
}
// trackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
func (gs *gracefulServer) trackRequestsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check if shutting down
if gs.isShuttingDown.Load() {
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
return
}
// Increment in-flight counter
gs.inFlightRequests.Add(1)
defer gs.inFlightRequests.Add(-1)
// Serve the request
next.ServeHTTP(w, r)
})
}
// shutdown performs graceful shutdown with request draining
func (gs *gracefulServer) shutdown(ctx context.Context) error {
var shutdownErr error
gs.shutdownOnce.Do(func() {
logger.Info("Starting graceful shutdown...")
// Mark as shutting down (new requests will be rejected)
gs.isShuttingDown.Store(true)
// Create context with timeout
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
defer cancel()
// Wait for in-flight requests to complete (with drain timeout)
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
defer drainCancel()
shutdownErr = gs.drainRequests(drainCtx)
if shutdownErr != nil {
logger.Error("Error draining requests: %v", shutdownErr)
}
// Shutdown the server
logger.Info("Shutting down HTTP server...")
if err := gs.server.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down server: %v", err)
if shutdownErr == nil {
shutdownErr = err
}
}
logger.Info("Graceful shutdown complete")
close(gs.shutdownComplete)
})
return shutdownErr
}
// drainRequests waits for in-flight requests to complete
func (gs *gracefulServer) drainRequests(ctx context.Context) error {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
startTime := time.Now()
for {
inFlight := gs.inFlightRequests.Load()
if inFlight == 0 {
logger.Info("All requests drained in %v", time.Since(startTime))
return nil
}
select {
case <-ctx.Done():
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
case <-ticker.C:
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
}
}
}
// inFlightRequests returns the current number of in-flight requests
func (gs *gracefulServer) inFlightRequestsCount() int64 {
return gs.inFlightRequests.Load()
}
// isShutdown returns true if the server is shutting down
func (gs *gracefulServer) isShutdown() bool {
return gs.isShuttingDown.Load()
}
// wait blocks until shutdown is complete
func (gs *gracefulServer) wait() {
<-gs.shutdownComplete
}
// healthCheckHandler returns a handler that responds to health checks
func (gs *gracefulServer) healthCheckHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if gs.isShutdown() {
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(`{"status":"healthy"}`))
if err != nil {
logger.Warn("Failed to write health check response: %v", err)
}
}
}
// readinessHandler returns a handler for readiness checks
func (gs *gracefulServer) readinessHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if gs.isShutdown() {
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
return
}
inFlight := gs.inFlightRequestsCount()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
}
}
// serverManager manages a collection of server instances with graceful shutdown support.
type serverManager struct {
instances map[string]Instance
mu sync.RWMutex
instances map[string]Instance
mu sync.RWMutex
shutdownCallbacks []ShutdownCallback
callbacksMu sync.Mutex
}
// NewManager creates a new server manager.
func NewManager() Manager {
return &serverManager{
instances: make(map[string]Instance),
instances: make(map[string]Instance),
shutdownCallbacks: make([]ShutdownCallback, 0),
}
}
@@ -74,7 +221,7 @@ func (sm *serverManager) Remove(name string) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := instance.Stop(ctx); err != nil {
logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err, context.Background())
logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err)
}
delete(sm.instances, name)
@@ -94,7 +241,6 @@ func (sm *serverManager) StartAll() error {
}
if len(startErrors) > 0 {
// In a real-world scenario, you might want a more sophisticated error handling strategy
return fmt.Errorf("encountered errors while starting servers: %v", startErrors)
}
return nil
@@ -102,6 +248,11 @@ func (sm *serverManager) StartAll() error {
// StopAll gracefully shuts down all running server instances.
func (sm *serverManager) StopAll() error {
return sm.StopAllWithContext(context.Background())
}
// StopAllWithContext gracefully shuts down all running server instances with a context.
func (sm *serverManager) StopAllWithContext(ctx context.Context) error {
sm.mu.RLock()
instancesToStop := make([]Instance, 0, len(sm.instances))
for _, instance := range sm.instances {
@@ -109,19 +260,38 @@ func (sm *serverManager) StopAll() error {
}
sm.mu.RUnlock()
logger.Info("Shutting down all servers...", context.Background())
logger.Info("Shutting down all servers...")
// Execute shutdown callbacks first
sm.callbacksMu.Lock()
callbacks := make([]ShutdownCallback, len(sm.shutdownCallbacks))
copy(callbacks, sm.shutdownCallbacks)
sm.callbacksMu.Unlock()
if len(callbacks) > 0 {
logger.Info("Executing %d shutdown callbacks...", len(callbacks))
for i, cb := range callbacks {
if err := cb(ctx); err != nil {
logger.Error("Shutdown callback %d failed: %v", i+1, err)
}
}
}
// Stop all instances in parallel
var shutdownErrors []error
var wg sync.WaitGroup
var errorsMu sync.Mutex
for _, instance := range instancesToStop {
wg.Add(1)
go func(inst Instance) {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
if err := inst.Stop(ctx); err != nil {
shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Addr(), err))
if err := inst.Stop(shutdownCtx); err != nil {
errorsMu.Lock()
shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Name(), err))
errorsMu.Unlock()
}
}(instance)
}
@@ -131,13 +301,13 @@ func (sm *serverManager) StopAll() error {
if len(shutdownErrors) > 0 {
return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors)
}
logger.Info("All servers stopped gracefully.", context.Background())
logger.Info("All servers stopped gracefully.")
return nil
}
// RestartAll gracefully restarts all running server instances.
func (sm *serverManager) RestartAll() error {
logger.Info("Restarting all servers...", context.Background())
logger.Info("Restarting all servers...")
if err := sm.StopAll(); err != nil {
return fmt.Errorf("failed to stop servers during restart: %w", err)
}
@@ -148,7 +318,7 @@ func (sm *serverManager) RestartAll() error {
if err := sm.StartAll(); err != nil {
return fmt.Errorf("failed to start servers during restart: %w", err)
}
logger.Info("All servers restarted successfully.", context.Background())
logger.Info("All servers restarted successfully.")
return nil
}
@@ -164,13 +334,46 @@ func (sm *serverManager) List() []Instance {
return instances
}
// RegisterShutdownCallback registers a callback to be called during shutdown.
func (sm *serverManager) RegisterShutdownCallback(cb ShutdownCallback) {
sm.callbacksMu.Lock()
defer sm.callbacksMu.Unlock()
sm.shutdownCallbacks = append(sm.shutdownCallbacks, cb)
}
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
func (sm *serverManager) ServeWithGracefulShutdown() error {
// Start all servers
if err := sm.StartAll(); err != nil {
return fmt.Errorf("failed to start servers: %w", err)
}
logger.Info("All servers started. Waiting for shutdown signal...")
// Wait for interrupt signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
sig := <-sigChan
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
// Create context with timeout for shutdown
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return sm.StopAllWithContext(ctx)
}
// serverInstance is a concrete implementation of the Instance interface.
// It wraps gracefulServer to provide graceful shutdown capabilities.
type serverInstance struct {
cfg Config
httpServer *http.Server
mu sync.RWMutex
running bool
stopCh chan struct{}
cfg Config
gracefulServer *gracefulServer
certFile string // Path to certificate file (may be temporary for self-signed)
keyFile string // Path to key file (may be temporary for self-signed)
mu sync.RWMutex
running bool
serverErr chan error
}
// newInstance creates a new, unstarted server instance from a config.
@@ -179,12 +382,29 @@ func newInstance(cfg Config) (*serverInstance, error) {
return nil, fmt.Errorf("handler cannot be nil")
}
// Set default timeouts
if cfg.ShutdownTimeout == 0 {
cfg.ShutdownTimeout = 30 * time.Second
}
if cfg.DrainTimeout == 0 {
cfg.DrainTimeout = 25 * time.Second
}
if cfg.ReadTimeout == 0 {
cfg.ReadTimeout = 15 * time.Second
}
if cfg.WriteTimeout == 0 {
cfg.WriteTimeout = 15 * time.Second
}
if cfg.IdleTimeout == 0 {
cfg.IdleTimeout = 60 * time.Second
}
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
var handler http.Handler = cfg.Handler
// Wrap with GZIP handler if enabled
if cfg.GZIP {
gz, err := gzhttp.NewWrapper(gzhttp.BestSpeed)
gz, err := gzhttp.NewWrapper()
if err != nil {
return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err)
}
@@ -194,20 +414,33 @@ func newInstance(cfg Config) (*serverInstance, error) {
// Wrap with the panic recovery middleware
handler = middleware.PanicRecovery(handler)
// Here you could add other default middleware like request logging, metrics, etc.
// Configure TLS if any TLS option is enabled
tlsConfig, certFile, keyFile, err := configureTLS(cfg)
if err != nil {
return nil, fmt.Errorf("failed to configure TLS: %w", err)
}
httpServer := &http.Server{
Addr: addr,
Handler: handler,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
// Create gracefulServer
gracefulSrv := &gracefulServer{
server: &http.Server{
Addr: addr,
Handler: handler,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
IdleTimeout: cfg.IdleTimeout,
TLSConfig: tlsConfig,
},
shutdownTimeout: cfg.ShutdownTimeout,
drainTimeout: cfg.DrainTimeout,
shutdownComplete: make(chan struct{}),
}
return &serverInstance{
cfg: cfg,
httpServer: httpServer,
stopCh: make(chan struct{}),
cfg: cfg,
gracefulServer: gracefulSrv,
certFile: certFile,
keyFile: keyFile,
serverErr: make(chan error, 1),
}, nil
}
@@ -220,42 +453,69 @@ func (s *serverInstance) Start() error {
return fmt.Errorf("server '%s' is already running", s.cfg.Name)
}
hasSSL := s.cfg.SSLCert != "" && s.cfg.SSLKey != ""
// Determine if we're using TLS
useTLS := s.cfg.SSLCert != "" || s.cfg.SSLKey != "" || s.cfg.SelfSignedSSL || s.cfg.AutoTLS
// Wrap handler with request tracking
s.gracefulServer.server.Handler = s.gracefulServer.trackRequestsMiddleware(s.gracefulServer.server.Handler)
go func() {
defer func() {
s.mu.Lock()
s.running = false
s.mu.Unlock()
logger.Info("Server '%s' stopped.", s.cfg.Name, context.Background())
logger.Info("Server '%s' stopped.", s.cfg.Name)
}()
var err error
protocol := "HTTP"
if hasSSL {
if useTLS {
protocol = "HTTPS"
// Configure TLS + HTTP/2
s.httpServer.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
// For AutoTLS, we need to use a TLS listener
if s.cfg.AutoTLS {
// Create listener
ln, lnErr := net.Listen("tcp", s.gracefulServer.server.Addr)
if lnErr != nil {
err = fmt.Errorf("failed to create listener: %w", lnErr)
} else {
// Wrap with TLS
tlsListener := tls.NewListener(ln, s.gracefulServer.server.TLSConfig)
err = s.gracefulServer.server.Serve(tlsListener)
}
} else {
// Use certificate files (regular SSL or self-signed)
err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile)
}
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr(), context.Background()) err = s.httpServer.ListenAndServeTLS(s.cfg.SSLCert, s.cfg.SSLKey)
} else {
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr(), context.Background())
err = s.httpServer.ListenAndServe()
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
err = s.gracefulServer.server.ListenAndServe()
}
// If the server stopped for a reason other than a graceful shutdown, log the error.
// If the server stopped for a reason other than a graceful shutdown, log and report the error.
if err != nil && err != http.ErrServerClosed {
logger.Error("Server '%s' failed: %v", s.cfg.Name, err, context.Background())
logger.Error("Server '%s' failed: %v", s.cfg.Name, err)
select {
case s.serverErr <- err:
default:
}
}
}()
s.running = true
// A small delay to allow the goroutine to start and potentially fail on binding.
// A more robust solution might involve a channel signal.
time.Sleep(50 * time.Millisecond)
// Check if the server failed to start
select {
case err := <-s.serverErr:
s.running = false
return err
default:
}
return nil
}
@@ -269,7 +529,7 @@ func (s *serverInstance) Stop(ctx context.Context) error {
}
logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name)
err := s.httpServer.Shutdown(ctx)
err := s.gracefulServer.shutdown(ctx)
if err == nil {
s.running = false
}
@@ -278,5 +538,35 @@ func (s *serverInstance) Stop(ctx context.Context) error {
// Addr returns the network address the server is listening on.
func (s *serverInstance) Addr() string {
return s.httpServer.Addr
return s.gracefulServer.server.Addr
}
// Name returns the server instance name.
func (s *serverInstance) Name() string {
return s.cfg.Name
}
// HealthCheckHandler returns a handler that responds to health checks.
func (s *serverInstance) HealthCheckHandler() http.HandlerFunc {
return s.gracefulServer.healthCheckHandler()
}
// ReadinessHandler returns a handler for readiness checks.
func (s *serverInstance) ReadinessHandler() http.HandlerFunc {
return s.gracefulServer.readinessHandler()
}
// InFlightRequests returns the current number of in-flight requests.
func (s *serverInstance) InFlightRequests() int64 {
return s.gracefulServer.inFlightRequestsCount()
}
// IsShuttingDown returns true if the server is shutting down.
func (s *serverInstance) IsShuttingDown() bool {
return s.gracefulServer.isShutdown()
}
// Wait blocks until shutdown is complete.
func (s *serverInstance) Wait() {
s.gracefulServer.wait()
}

View File

@@ -1,10 +1,12 @@
package server
import (
"context"
"fmt"
"io"
"net"
"net/http"
"sync"
"testing"
"time"
@@ -123,3 +125,204 @@ func TestManagerErrorCases(t *testing.T) {
_, err = sm.Add(config3)
require.Error(t, err, "should not be able to add a server with a nil handler")
}
func TestGracefulShutdown(t *testing.T) {
logger.Init(true)
sm := NewManager()
requestsHandled := 0
var requestsMu sync.Mutex
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestsMu.Lock()
requestsHandled++
requestsMu.Unlock()
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
})
testPort := getFreePort(t)
instance, err := sm.Add(Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: handler,
DrainTimeout: 2 * time.Second,
})
require.NoError(t, err)
err = sm.StartAll()
require.NoError(t, err)
// Give server time to start
time.Sleep(100 * time.Millisecond)
// Send some concurrent requests
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
client := &http.Client{Timeout: 5 * time.Second}
url := fmt.Sprintf("http://localhost:%d", testPort)
resp, err := client.Get(url)
if err == nil {
resp.Body.Close()
}
}()
}
// Wait a bit for requests to start
time.Sleep(50 * time.Millisecond)
// Check in-flight requests
inFlight := instance.InFlightRequests()
assert.Greater(t, inFlight, int64(0), "Should have in-flight requests")
// Stop the server
err = sm.StopAll()
require.NoError(t, err)
// Wait for all requests to complete
wg.Wait()
// Verify all requests were handled
requestsMu.Lock()
handled := requestsHandled
requestsMu.Unlock()
assert.GreaterOrEqual(t, handled, 1, "At least some requests should have been handled")
// Verify no in-flight requests
assert.Equal(t, int64(0), instance.InFlightRequests(), "Should have no in-flight requests after shutdown")
}
func TestHealthAndReadinessEndpoints(t *testing.T) {
logger.Init(true)
sm := NewManager()
mux := http.NewServeMux()
testPort := getFreePort(t)
instance, err := sm.Add(Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: mux,
})
require.NoError(t, err)
// Add health and readiness endpoints
mux.HandleFunc("/health", instance.HealthCheckHandler())
mux.HandleFunc("/ready", instance.ReadinessHandler())
err = sm.StartAll()
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
client := &http.Client{Timeout: 2 * time.Second}
baseURL := fmt.Sprintf("http://localhost:%d", testPort)
// Test health endpoint
resp, err := client.Get(baseURL + "/health")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
assert.Contains(t, string(body), "healthy")
// Test readiness endpoint
resp, err = client.Get(baseURL + "/ready")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, _ = io.ReadAll(resp.Body)
resp.Body.Close()
assert.Contains(t, string(body), "ready")
assert.Contains(t, string(body), "in_flight_requests")
// Stop the server
sm.StopAll()
}
func TestRequestRejectionDuringShutdown(t *testing.T) {
logger.Init(true)
sm := NewManager()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(50 * time.Millisecond)
w.WriteHeader(http.StatusOK)
})
testPort := getFreePort(t)
_, err := sm.Add(Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: handler,
DrainTimeout: 1 * time.Second,
})
require.NoError(t, err)
err = sm.StartAll()
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
// Start shutdown in background
go func() {
time.Sleep(50 * time.Millisecond)
sm.StopAll()
}()
// Give shutdown time to start
time.Sleep(100 * time.Millisecond)
// Try to make a request after shutdown started
client := &http.Client{Timeout: 2 * time.Second}
url := fmt.Sprintf("http://localhost:%d", testPort)
resp, err := client.Get(url)
// The request should either fail (connection refused) or get 503
if err == nil {
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Should get 503 during shutdown")
resp.Body.Close()
}
}
func TestShutdownCallbacks(t *testing.T) {
logger.Init(true)
sm := NewManager()
callbackExecuted := false
var callbackMu sync.Mutex
sm.RegisterShutdownCallback(func(ctx context.Context) error {
callbackMu.Lock()
callbackExecuted = true
callbackMu.Unlock()
return nil
})
testPort := getFreePort(t)
_, err := sm.Add(Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: http.NewServeMux(),
})
require.NoError(t, err)
err = sm.StartAll()
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = sm.StopAll()
require.NoError(t, err)
callbackMu.Lock()
executed := callbackExecuted
callbackMu.Unlock()
assert.True(t, executed, "Shutdown callback should have been executed")
}

View File

@@ -1,296 +0,0 @@
package server
import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// GracefulServer wraps http.Server with graceful shutdown capabilities
type GracefulServer struct {
server *http.Server
shutdownTimeout time.Duration
drainTimeout time.Duration
inFlightRequests atomic.Int64
isShuttingDown atomic.Bool
shutdownOnce sync.Once
shutdownComplete chan struct{}
}
// Config holds configuration for the graceful server
type Config struct {
// Addr is the server address (e.g., ":8080")
Addr string
// Handler is the HTTP handler
Handler http.Handler
// ShutdownTimeout is the maximum time to wait for graceful shutdown
// Default: 30 seconds
ShutdownTimeout time.Duration
// DrainTimeout is the time to wait for in-flight requests to complete
// before forcing shutdown. Default: 25 seconds
DrainTimeout time.Duration
// ReadTimeout is the maximum duration for reading the entire request
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out writes of the response
WriteTimeout time.Duration
// IdleTimeout is the maximum amount of time to wait for the next request
IdleTimeout time.Duration
}
// NewGracefulServer creates a new graceful server
func NewGracefulServer(config Config) *GracefulServer {
if config.ShutdownTimeout == 0 {
config.ShutdownTimeout = 30 * time.Second
}
if config.DrainTimeout == 0 {
config.DrainTimeout = 25 * time.Second
}
if config.ReadTimeout == 0 {
config.ReadTimeout = 10 * time.Second
}
if config.WriteTimeout == 0 {
config.WriteTimeout = 10 * time.Second
}
if config.IdleTimeout == 0 {
config.IdleTimeout = 120 * time.Second
}
gs := &GracefulServer{
server: &http.Server{
Addr: config.Addr,
Handler: config.Handler,
ReadTimeout: config.ReadTimeout,
WriteTimeout: config.WriteTimeout,
IdleTimeout: config.IdleTimeout,
},
shutdownTimeout: config.ShutdownTimeout,
drainTimeout: config.DrainTimeout,
shutdownComplete: make(chan struct{}),
}
return gs
}
// TrackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
func (gs *GracefulServer) TrackRequestsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check if shutting down
if gs.isShuttingDown.Load() {
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
return
}
// Increment in-flight counter
gs.inFlightRequests.Add(1)
defer gs.inFlightRequests.Add(-1)
// Serve the request
next.ServeHTTP(w, r)
})
}
// ListenAndServe starts the server and handles graceful shutdown
func (gs *GracefulServer) ListenAndServe() error {
// Wrap handler with request tracking
gs.server.Handler = gs.TrackRequestsMiddleware(gs.server.Handler)
// Start server in goroutine
serverErr := make(chan error, 1)
go func() {
logger.Info("Starting server on %s", gs.server.Addr)
if err := gs.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
serverErr <- err
}
close(serverErr)
}()
// Wait for interrupt signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
select {
case err := <-serverErr:
return err
case sig := <-sigChan:
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
return gs.Shutdown(context.Background())
}
}
// Shutdown performs graceful shutdown with request draining
func (gs *GracefulServer) Shutdown(ctx context.Context) error {
var shutdownErr error
gs.shutdownOnce.Do(func() {
logger.Info("Starting graceful shutdown...")
// Mark as shutting down (new requests will be rejected)
gs.isShuttingDown.Store(true)
// Create context with timeout
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
defer cancel()
// Wait for in-flight requests to complete (with drain timeout)
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
defer drainCancel()
shutdownErr = gs.drainRequests(drainCtx)
if shutdownErr != nil {
logger.Error("Error draining requests: %v", shutdownErr)
}
// Shutdown the server
logger.Info("Shutting down HTTP server...")
if err := gs.server.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down server: %v", err)
if shutdownErr == nil {
shutdownErr = err
}
}
logger.Info("Graceful shutdown complete")
close(gs.shutdownComplete)
})
return shutdownErr
}
// drainRequests waits for in-flight requests to complete
func (gs *GracefulServer) drainRequests(ctx context.Context) error {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
startTime := time.Now()
for {
inFlight := gs.inFlightRequests.Load()
if inFlight == 0 {
logger.Info("All requests drained in %v", time.Since(startTime))
return nil
}
select {
case <-ctx.Done():
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
case <-ticker.C:
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
}
}
}
// InFlightRequests returns the current number of in-flight requests
func (gs *GracefulServer) InFlightRequests() int64 {
return gs.inFlightRequests.Load()
}
// IsShuttingDown returns true if the server is shutting down
func (gs *GracefulServer) IsShuttingDown() bool {
return gs.isShuttingDown.Load()
}
// Wait blocks until shutdown is complete
func (gs *GracefulServer) Wait() {
<-gs.shutdownComplete
}
// HealthCheckHandler returns a handler that responds to health checks
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down
func (gs *GracefulServer) HealthCheckHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if gs.IsShuttingDown() {
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(`{"status":"healthy"}`))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
}
}
// ReadinessHandler returns a handler for readiness checks
// Includes in-flight request count
func (gs *GracefulServer) ReadinessHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if gs.IsShuttingDown() {
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
return
}
inFlight := gs.InFlightRequests()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
}
}
// ShutdownCallback is a function called during shutdown
type ShutdownCallback func(context.Context) error
// shutdownCallbacks stores registered shutdown callbacks
var (
shutdownCallbacks []ShutdownCallback
shutdownCallbacksMu sync.Mutex
)
// RegisterShutdownCallback registers a callback to be called during shutdown
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
func RegisterShutdownCallback(cb ShutdownCallback) {
shutdownCallbacksMu.Lock()
defer shutdownCallbacksMu.Unlock()
shutdownCallbacks = append(shutdownCallbacks, cb)
}
// executeShutdownCallbacks runs all registered shutdown callbacks
func executeShutdownCallbacks(ctx context.Context) error {
shutdownCallbacksMu.Lock()
callbacks := make([]ShutdownCallback, len(shutdownCallbacks))
copy(callbacks, shutdownCallbacks)
shutdownCallbacksMu.Unlock()
var errors []error
for i, cb := range callbacks {
logger.Debug("Executing shutdown callback %d/%d", i+1, len(callbacks))
if err := cb(ctx); err != nil {
logger.Error("Shutdown callback %d failed: %v", i+1, err)
errors = append(errors, err)
}
}
if len(errors) > 0 {
return fmt.Errorf("shutdown callbacks failed: %v", errors)
}
return nil
}
// ShutdownWithCallbacks performs shutdown and executes all registered callbacks
func (gs *GracefulServer) ShutdownWithCallbacks(ctx context.Context) error {
// Execute callbacks first
if err := executeShutdownCallbacks(ctx); err != nil {
logger.Error("Error executing shutdown callbacks: %v", err)
}
// Then shutdown the server
return gs.Shutdown(ctx)
}

View File

@@ -1,231 +0,0 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
)
func TestGracefulServerTrackRequests(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}),
})
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
// Start some requests
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}()
}
// Wait a bit for requests to start
time.Sleep(10 * time.Millisecond)
// Check in-flight count
inFlight := srv.InFlightRequests()
if inFlight == 0 {
t.Error("Should have in-flight requests")
}
// Wait for all requests to complete
wg.Wait()
// Check that counter is back to zero
inFlight = srv.InFlightRequests()
if inFlight != 0 {
t.Errorf("In-flight requests should be 0, got %d", inFlight)
}
}
func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
})
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
// Mark as shutting down
srv.isShuttingDown.Store(true)
// Try to make a request
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// Should get 503
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected 503, got %d", w.Code)
}
}
func TestHealthCheckHandler(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
})
handler := srv.HealthCheckHandler()
// Healthy
t.Run("Healthy", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected 200, got %d", w.Code)
}
if w.Body.String() != `{"status":"healthy"}` {
t.Errorf("Unexpected body: %s", w.Body.String())
}
})
// Shutting down
t.Run("ShuttingDown", func(t *testing.T) {
srv.isShuttingDown.Store(true)
req := httptest.NewRequest("GET", "/health", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected 503, got %d", w.Code)
}
})
}
func TestReadinessHandler(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
})
handler := srv.ReadinessHandler()
// Ready with no in-flight requests
t.Run("Ready", func(t *testing.T) {
req := httptest.NewRequest("GET", "/ready", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected 200, got %d", w.Code)
}
body := w.Body.String()
if body != `{"ready":true,"in_flight_requests":0}` {
t.Errorf("Unexpected body: %s", body)
}
})
// Not ready during shutdown
t.Run("NotReady", func(t *testing.T) {
srv.isShuttingDown.Store(true)
req := httptest.NewRequest("GET", "/ready", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected 503, got %d", w.Code)
}
})
}
func TestShutdownCallbacks(t *testing.T) {
callbackExecuted := false
RegisterShutdownCallback(func(ctx context.Context) error {
callbackExecuted = true
return nil
})
ctx := context.Background()
err := executeShutdownCallbacks(ctx)
if err != nil {
t.Errorf("executeShutdownCallbacks() error = %v", err)
}
if !callbackExecuted {
t.Error("Shutdown callback was not executed")
}
// Reset for other tests
shutdownCallbacks = nil
}
func TestDrainRequests(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
DrainTimeout: 1 * time.Second,
})
// Simulate in-flight requests
srv.inFlightRequests.Add(3)
// Start draining in background
go func() {
time.Sleep(100 * time.Millisecond)
// Simulate requests completing
srv.inFlightRequests.Add(-3)
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err := srv.drainRequests(ctx)
if err != nil {
t.Errorf("drainRequests() error = %v", err)
}
if srv.InFlightRequests() != 0 {
t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests())
}
}
func TestDrainRequestsTimeout(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
DrainTimeout: 100 * time.Millisecond,
})
// Simulate in-flight requests that don't complete
srv.inFlightRequests.Add(5)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
err := srv.drainRequests(ctx)
if err == nil {
t.Error("drainRequests() should timeout with error")
}
// Cleanup
srv.inFlightRequests.Add(-5)
}
func TestGetClientIP(t *testing.T) {
// This test is in ratelimit_test.go since getClientIP is used by rate limiter
// Including here for completeness of server tests
}

190
pkg/server/tls.go Normal file
View File

@@ -0,0 +1,190 @@
package server
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"path/filepath"
"time"
"golang.org/x/crypto/acme/autocert"
)
// generateSelfSignedCert generates a self-signed certificate for the given host.
// Returns the certificate and private key in PEM format.
func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) {
// Generate private key
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}
// Create certificate template
notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"ResolveSpec Self-Signed"},
CommonName: host,
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
// Add host as DNS name or IP address
if ip := net.ParseIP(host); ip != nil {
template.IPAddresses = []net.IP{ip}
} else {
template.DNSNames = []string{host}
}
// Create certificate
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
}
// Encode certificate to PEM
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
// Encode private key to PEM
privBytes, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
return certPEM, keyPEM, nil
}
// saveCertToTempFiles saves certificate and key PEM data to temporary files.
// Returns the file paths for the certificate and key.
func saveCertToTempFiles(certPEM, keyPEM []byte) (certFile, keyFile string, err error) {
// Create temporary directory
tmpDir, err := os.MkdirTemp("", "resolvespec-certs-*")
if err != nil {
return "", "", fmt.Errorf("failed to create temp directory: %w", err)
}
certFile = filepath.Join(tmpDir, "cert.pem")
keyFile = filepath.Join(tmpDir, "key.pem")
// Write certificate
if err := os.WriteFile(certFile, certPEM, 0600); err != nil {
os.RemoveAll(tmpDir)
return "", "", fmt.Errorf("failed to write certificate: %w", err)
}
// Write key
if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil {
os.RemoveAll(tmpDir)
return "", "", fmt.Errorf("failed to write private key: %w", err)
}
return certFile, keyFile, nil
}
// setupAutoTLS configures automatic TLS certificate management using Let's Encrypt.
// Returns a TLS config that can be used with http.Server.
func setupAutoTLS(domains []string, email, cacheDir string) (*tls.Config, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("at least one domain must be specified for AutoTLS")
}
// Set default cache directory
if cacheDir == "" {
cacheDir = "./certs-cache"
}
// Create cache directory if it doesn't exist
if err := os.MkdirAll(cacheDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create certificate cache directory: %w", err)
}
// Create autocert manager
m := &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(cacheDir),
HostPolicy: autocert.HostWhitelist(domains...),
Email: email,
}
// Create TLS config
tlsConfig := m.TLSConfig()
tlsConfig.MinVersion = tls.VersionTLS12
return tlsConfig, nil
}
// configureTLS configures TLS for the server based on the provided configuration.
// Returns the TLS config and certificate/key file paths (if applicable).
func configureTLS(cfg Config) (*tls.Config, string, string, error) {
// Option 1: Certificate files provided
if cfg.SSLCert != "" && cfg.SSLKey != "" {
// Validate that files exist
if _, err := os.Stat(cfg.SSLCert); os.IsNotExist(err) {
return nil, "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert)
}
if _, err := os.Stat(cfg.SSLKey); os.IsNotExist(err) {
return nil, "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey)
}
// Return basic TLS config - cert/key will be loaded by ListenAndServeTLS
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}
return tlsConfig, cfg.SSLCert, cfg.SSLKey, nil
}
// Option 2: Auto TLS (Let's Encrypt)
if cfg.AutoTLS {
tlsConfig, err := setupAutoTLS(cfg.AutoTLSDomains, cfg.AutoTLSEmail, cfg.AutoTLSCacheDir)
if err != nil {
return nil, "", "", fmt.Errorf("failed to setup AutoTLS: %w", err)
}
return tlsConfig, "", "", nil
}
// Option 3: Self-signed certificate
if cfg.SelfSignedSSL {
host := cfg.Host
if host == "" || host == "0.0.0.0" {
host = "localhost"
}
certPEM, keyPEM, err := generateSelfSignedCert(host)
if err != nil {
return nil, "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err)
}
certFile, keyFile, err := saveCertToTempFiles(certPEM, keyPEM)
if err != nil {
return nil, "", "", fmt.Errorf("failed to save self-signed certificate: %w", err)
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}
return tlsConfig, certFile, keyFile, nil
}
return nil, "", "", nil
}