diff --git a/go.mod b/go.mod index 3e97432..2033ea6 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.3 github.com/jackc/pgx/v5 v5.6.0 + github.com/klauspost/compress v1.18.0 github.com/prometheus/client_golang v1.23.2 github.com/redis/go-redis/v9 v9.17.1 github.com/spf13/viper v1.21.0 @@ -30,6 +31,7 @@ require ( go.opentelemetry.io/otel/sdk v1.38.0 go.opentelemetry.io/otel/trace v1.38.0 go.uber.org/zap v1.27.0 + golang.org/x/crypto v0.43.0 golang.org/x/time v0.14.0 gorm.io/driver/postgres v1.6.0 gorm.io/gorm v1.30.0 @@ -70,7 +72,6 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/klauspost/compress v1.18.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -122,7 +123,6 @@ require ( go.uber.org/multierr v1.10.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/crypto v0.43.0 // indirect golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/net v0.45.0 // indirect golang.org/x/sync v0.18.0 // indirect diff --git a/pkg/errortracking/README.md b/pkg/errortracking/README.md index a9950c2..89436ff 100644 --- a/pkg/errortracking/README.md +++ b/pkg/errortracking/README.md @@ -90,12 +90,12 @@ Panics are automatically captured when using the logger's panic handlers: ```go // Using CatchPanic -defer logger.CatchPanic("MyFunction") +defer logger.CatchPanic("MyFunction")() // Using CatchPanicCallback defer logger.CatchPanicCallback("MyFunction", func(err any) { // Custom cleanup -}) +})() // Using HandlePanic defer func() { diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index a033b1b..d1c7705 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -75,6 +75,28 @@ func CloseErrorTracking() error { return nil } +// extractContext attempts to find a context.Context in the given arguments. +// It returns the found context (or context.Background() if not found) and +// the remaining arguments without the context. +func extractContext(args ...interface{}) (context.Context, []interface{}) { + ctx := context.Background() + var newArgs []interface{} + found := false + + for _, arg := range args { + if c, ok := arg.(context.Context); ok { + if !found { + ctx = c + found = true + } + // Ignore any additional context.Context arguments after the first one. + continue + } + newArgs = append(newArgs, arg) + } + return ctx, newArgs +} + func Info(template string, args ...interface{}) { if Logger == nil { log.Printf(template, args...) @@ -84,7 +106,8 @@ func Info(template string, args ...interface{}) { } func Warn(template string, args ...interface{}) { - message := fmt.Sprintf(template, args...) + ctx, remainingArgs := extractContext(args...) + message := fmt.Sprintf(template, remainingArgs...) if Logger == nil { log.Printf("%s", message) } else { @@ -93,14 +116,15 @@ func Warn(template string, args ...interface{}) { // Send to error tracker if errorTracker != nil { - errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{ + errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{ "process_id": os.Getpid(), }) } } func Error(template string, args ...interface{}) { - message := fmt.Sprintf(template, args...) 
+ ctx, remainingArgs := extractContext(args...) + message := fmt.Sprintf(template, remainingArgs...) if Logger == nil { log.Printf("%s", message) } else { @@ -109,7 +133,7 @@ func Error(template string, args ...interface{}) { // Send to error tracker if errorTracker != nil { - errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{ + errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{ "process_id": os.Getpid(), }) } @@ -124,34 +148,41 @@ func Debug(template string, args ...interface{}) { } // CatchPanic - Handle panic -func CatchPanicCallback(location string, cb func(err any)) { - if err := recover(); err != nil { - callstack := debug.Stack() +// Returns a function that should be deferred to catch panics +// Example usage: defer CatchPanicCallback("MyFunction", func(err any) { /* cleanup */ })() +func CatchPanicCallback(location string, cb func(err any), args ...interface{}) func() { + ctx, _ := extractContext(args...) + return func() { + if err := recover(); err != nil { + callstack := debug.Stack() - if Logger != nil { - Error("Panic in %s : %v", location, err) - } else { - fmt.Printf("%s:PANIC->%+v", location, err) - debug.PrintStack() - } + if Logger != nil { + Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly + } else { + fmt.Printf("%s:PANIC->%+v", location, err) + debug.PrintStack() + } - // Send to error tracker - if errorTracker != nil { - errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{ - "location": location, - "process_id": os.Getpid(), - }) - } + // Send to error tracker + if errorTracker != nil { + errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{ + "location": location, + "process_id": os.Getpid(), + }) + } - if cb != nil { - cb(err) + if cb != nil { + cb(err) + } } } } // CatchPanic - Handle panic -func CatchPanic(location string) { - CatchPanicCallback(location, nil) +// Returns a function that should be deferred to catch panics +// Example usage: defer CatchPanic("MyFunction")() +func CatchPanic(location string, args ...interface{}) func() { + return CatchPanicCallback(location, nil, args...) } // HandlePanic logs a panic and returns it as an error @@ -163,13 +194,14 @@ func CatchPanic(location string) { // err = logger.HandlePanic("MethodName", r) // } // }() -func HandlePanic(methodName string, r any) error { +func HandlePanic(methodName string, r any, args ...interface{}) error { + ctx, _ := extractContext(args...) 
stack := debug.Stack() - Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack)) + Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly // Send to error tracker if errorTracker != nil { - errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{ + errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{ "method": methodName, "process_id": os.Getpid(), }) diff --git a/pkg/metrics/interfaces.go b/pkg/metrics/interfaces.go index 9d11586..8b35463 100644 --- a/pkg/metrics/interfaces.go +++ b/pkg/metrics/interfaces.go @@ -39,6 +39,9 @@ type Provider interface { // UpdateEventQueueSize updates the event queue size metric UpdateEventQueueSize(size int64) + // RecordPanic records a panic event + RecordPanic(methodName string) + // Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint) Handler() http.Handler } @@ -75,6 +78,7 @@ func (n *NoOpProvider) RecordEventPublished(source, eventType string) {} func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) { } func (n *NoOpProvider) UpdateEventQueueSize(size int64) {} +func (n *NoOpProvider) RecordPanic(methodName string) {} func (n *NoOpProvider) Handler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) diff --git a/pkg/metrics/prometheus.go b/pkg/metrics/prometheus.go index 49c6ebb..9761a68 100644 --- a/pkg/metrics/prometheus.go +++ b/pkg/metrics/prometheus.go @@ -20,6 +20,7 @@ type PrometheusProvider struct { cacheHits *prometheus.CounterVec cacheMisses *prometheus.CounterVec cacheSize *prometheus.GaugeVec + panicsTotal *prometheus.CounterVec } // NewPrometheusProvider creates a new Prometheus metrics provider @@ -83,6 +84,13 @@ func NewPrometheusProvider() *PrometheusProvider { }, []string{"provider"}, ), + panicsTotal: promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "panics_total", + Help: "Total number of panics", + }, + []string{"method"}, + ), } } @@ -145,6 +153,11 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) { p.cacheSize.WithLabelValues(provider).Set(float64(size)) } +// RecordPanic implements the Provider interface +func (p *PrometheusProvider) RecordPanic(methodName string) { + p.panicsTotal.WithLabelValues(methodName).Inc() +} + // Handler implements Provider interface func (p *PrometheusProvider) Handler() http.Handler { return promhttp.Handler() diff --git a/pkg/middleware/panic.go b/pkg/middleware/panic.go new file mode 100644 index 0000000..fa36086 --- /dev/null +++ b/pkg/middleware/panic.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "net/http" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/metrics" +) + +const panicMiddlewareMethodName = "PanicMiddleware" + +// PanicRecovery is a middleware that recovers from panics, logs the error, +// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error. +func PanicRecovery(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rcv := recover(); rcv != nil { + // Record the panic metric + metrics.GetProvider().RecordPanic(panicMiddlewareMethodName) + + // Log the panic and send to error tracker + // We pass the request context so the error tracker can potentially + // link the panic to the request trace. 
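+				// HandlePanic logs the panic with its stack trace, forwards it to the
+				// error tracker, and returns it wrapped as an error for the response below.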
+ ctx := r.Context() + err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx) + + // Respond with a 500 error + http.Error(w, err.Error(), http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/middleware/panic_test.go b/pkg/middleware/panic_test.go new file mode 100644 index 0000000..e5cb77e --- /dev/null +++ b/pkg/middleware/panic_test.go @@ -0,0 +1,86 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/metrics" + "github.com/stretchr/testify/assert" +) + +// mockMetricsProvider is a mock for the metrics provider to check if methods are called. +type mockMetricsProvider struct { + metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods + panicRecorded bool + methodName string +} + +func (m *mockMetricsProvider) RecordPanic(methodName string) { + m.panicRecorded = true + m.methodName = methodName +} + +func TestPanicRecovery(t *testing.T) { + // Initialize a mock logger to avoid actual logging output during tests + logger.Init(true) + + // Setup mock metrics provider + mockProvider := &mockMetricsProvider{} + originalProvider := metrics.GetProvider() + metrics.SetProvider(mockProvider) + defer metrics.SetProvider(originalProvider) // Restore original provider after test + + // 1. Test case: A handler that panics + t.Run("recovers from panic and returns 500", func(t *testing.T) { + // Reset mock state for this sub-test + mockProvider.panicRecorded = false + mockProvider.methodName = "" + + panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("something went terribly wrong") + }) + + // Create the middleware wrapping the panicking handler + testHandler := PanicRecovery(panicHandler) + + // Create a test request and response recorder + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + // Serve the request + testHandler.ServeHTTP(rr, req) + + // Assertions + assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500") + assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body") + + // Assert that the metric was recorded + assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider") + assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name") + }) + + // 2. 
Test case: A handler that does NOT panic + t.Run("does not interfere with a non-panicking handler", func(t *testing.T) { + // Reset mock state for this sub-test + mockProvider.panicRecorded = false + + successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + testHandler := PanicRecovery(successHandler) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + testHandler.ServeHTTP(rr, req) + + // Assertions + assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200") + assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body") + assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic") + }) +} diff --git a/pkg/security/provider.go b/pkg/security/provider.go index 2eb7b4a..7dcf609 100644 --- a/pkg/security/provider.go +++ b/pkg/security/provider.go @@ -296,7 +296,7 @@ func setColSecValue(fieldsrc reflect.Value, colsec ColumnSecurity, fieldTypeName } func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) (reflect.Value, error) { - defer logger.CatchPanic("ApplyColumnSecurity") + defer logger.CatchPanic("ApplyColumnSecurity")() if m.ColumnSecurity == nil { return records, fmt.Errorf("security not initialized") @@ -437,7 +437,7 @@ func (m *SecurityList) LoadRowSecurity(ctx context.Context, pUserID int, pSchema } func (m *SecurityList) GetRowSecurityTemplate(pUserID int, pSchema, pTablename string) (RowSecurity, error) { - defer logger.CatchPanic("GetRowSecurityTemplate") + defer logger.CatchPanic("GetRowSecurityTemplate")() if m.RowSecurity == nil { return RowSecurity{}, fmt.Errorf("security not initialized") diff --git a/pkg/server/README.md b/pkg/server/README.md index 45342c9..6cd46b3 100644 --- a/pkg/server/README.md +++ b/pkg/server/README.md @@ -1,233 +1,314 @@ # Server Package -Graceful HTTP server with request draining and shutdown coordination. +Production-ready HTTP server manager with graceful shutdown, request draining, and comprehensive TLS/HTTPS support. 
+ +## Features + +✅ **Multiple Server Management** - Run multiple HTTP/HTTPS servers concurrently +✅ **Graceful Shutdown** - Handles SIGINT/SIGTERM with request draining +✅ **Automatic Request Rejection** - New requests get 503 during shutdown +✅ **Health & Readiness Endpoints** - Kubernetes-ready health checks +✅ **Shutdown Callbacks** - Register cleanup functions (DB, cache, metrics) +✅ **Comprehensive TLS Support**: + - Certificate files (production) + - Self-signed certificates (development/testing) + - Let's Encrypt / AutoTLS (automatic certificate management) +✅ **GZIP Compression** - Optional response compression +✅ **Panic Recovery** - Automatic panic recovery middleware +✅ **Configurable Timeouts** - Read, write, idle, drain, and shutdown timeouts ## Quick Start +### Single Server + ```go import "github.com/bitechdev/ResolveSpec/pkg/server" -// Create server -srv := server.NewGracefulServer(server.Config{ - Addr: ":8080", - Handler: router, +// Create server manager +mgr := server.NewManager() + +// Add server +_, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: myRouter, + GZIP: true, }) -// Start server (blocks until shutdown signal) -if err := srv.ListenAndServe(); err != nil { +// Start and wait for shutdown signal +if err := mgr.ServeWithGracefulShutdown(); err != nil { log.Fatal(err) } ``` -## Features +### Multiple Servers -✅ Graceful shutdown on SIGINT/SIGTERM -✅ Request draining (waits for in-flight requests) -✅ Automatic request rejection during shutdown -✅ Health and readiness endpoints -✅ Shutdown callbacks for cleanup -✅ Configurable timeouts +```go +mgr := server.NewManager() + +// Public API +mgr.Add(server.Config{ + Name: "public-api", + Port: 8080, + Handler: publicRouter, +}) + +// Admin API +mgr.Add(server.Config{ + Name: "admin-api", + Port: 8081, + Handler: adminRouter, +}) + +// Start all and wait +mgr.ServeWithGracefulShutdown() +``` + +## HTTPS/TLS Configuration + +### Option 1: Certificate Files (Production) + +```go +mgr.Add(server.Config{ + Name: "https-server", + Host: "0.0.0.0", + Port: 443, + Handler: handler, + SSLCert: "/etc/ssl/certs/server.crt", + SSLKey: "/etc/ssl/private/server.key", +}) +``` + +### Option 2: Self-Signed Certificate (Development) + +```go +mgr.Add(server.Config{ + Name: "dev-server", + Host: "localhost", + Port: 8443, + Handler: handler, + SelfSignedSSL: true, // Auto-generates certificate +}) +``` + +### Option 3: Let's Encrypt / AutoTLS (Production) + +```go +mgr.Add(server.Config{ + Name: "prod-server", + Host: "0.0.0.0", + Port: 443, + Handler: handler, + AutoTLS: true, + AutoTLSDomains: []string{"example.com", "www.example.com"}, + AutoTLSEmail: "admin@example.com", + AutoTLSCacheDir: "./certs-cache", // Certificate cache directory +}) +``` ## Configuration ```go -config := server.Config{ - // Server address - Addr: ":8080", +server.Config{ + // Basic configuration + Name: "my-server", // Server name (required) + Host: "0.0.0.0", // Bind address + Port: 8080, // Port (required) + Handler: myRouter, // HTTP handler (required) + Description: "My API server", // Optional description - // HTTP handler - Handler: myRouter, + // Features + GZIP: true, // Enable GZIP compression - // Maximum time for graceful shutdown (default: 30s) - ShutdownTimeout: 30 * time.Second, + // TLS/HTTPS (choose one option) + SSLCert: "/path/to/cert.pem", // Certificate file + SSLKey: "/path/to/key.pem", // Key file + SelfSignedSSL: false, // Auto-generate self-signed cert + AutoTLS: false, // Let's 
Encrypt + AutoTLSDomains: []string{}, // Domains for AutoTLS + AutoTLSEmail: "", // Email for Let's Encrypt + AutoTLSCacheDir: "./certs-cache", // Cert cache directory - // Time to wait for in-flight requests (default: 25s) - DrainTimeout: 25 * time.Second, - - // Request read timeout (default: 10s) - ReadTimeout: 10 * time.Second, - - // Response write timeout (default: 10s) - WriteTimeout: 10 * time.Second, - - // Idle connection timeout (default: 120s) - IdleTimeout: 120 * time.Second, + // Timeouts + ShutdownTimeout: 30 * time.Second, // Max shutdown time + DrainTimeout: 25 * time.Second, // Request drain timeout + ReadTimeout: 15 * time.Second, // Request read timeout + WriteTimeout: 15 * time.Second, // Response write timeout + IdleTimeout: 60 * time.Second, // Idle connection timeout } - -srv := server.NewGracefulServer(config) ``` -## Shutdown Behavior +## Graceful Shutdown -**Signal received (SIGINT/SIGTERM):** +### Automatic (Recommended) -1. **Mark as shutting down** - New requests get 503 -2. **Drain requests** - Wait up to `DrainTimeout` for in-flight requests -3. **Shutdown server** - Close listeners and connections -4. **Execute callbacks** - Run registered cleanup functions +```go +mgr := server.NewManager() +// Add servers... + +// This blocks until SIGINT/SIGTERM +mgr.ServeWithGracefulShutdown() ``` -Time Event -───────────────────────────────────────── -0s Signal received: SIGTERM - ├─ Mark as shutting down - ├─ Reject new requests (503) - └─ Start draining... -1s In-flight: 50 requests -2s In-flight: 32 requests -3s In-flight: 12 requests -4s In-flight: 3 requests -5s In-flight: 0 requests ✓ - └─ All requests drained +### Manual Control -5s Execute shutdown callbacks -6s Shutdown complete +```go +mgr := server.NewManager() + +// Add and start servers +mgr.StartAll() + +// Later... 
stop gracefully +ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +defer cancel() + +if err := mgr.StopAllWithContext(ctx); err != nil { + log.Printf("Shutdown error: %v", err) +} +``` + +### Shutdown Callbacks + +Register cleanup functions to run during shutdown: + +```go +// Close database +mgr.RegisterShutdownCallback(func(ctx context.Context) error { + log.Println("Closing database...") + return db.Close() +}) + +// Flush metrics +mgr.RegisterShutdownCallback(func(ctx context.Context) error { + log.Println("Flushing metrics...") + return metrics.Flush(ctx) +}) + +// Close cache +mgr.RegisterShutdownCallback(func(ctx context.Context) error { + log.Println("Closing cache...") + return cache.Close() +}) ``` ## Health Checks -### Health Endpoint - -Returns 200 when healthy, 503 when shutting down: +### Adding Health Endpoints ```go -router.HandleFunc("/health", srv.HealthCheckHandler()) +instance, _ := mgr.Add(server.Config{ + Name: "api-server", + Port: 8080, + Handler: router, +}) + +// Add health endpoints to your router +router.HandleFunc("/health", instance.HealthCheckHandler()) +router.HandleFunc("/ready", instance.ReadinessHandler()) ``` -**Response (healthy):** +### Health Endpoint + +Returns server health status: + +**Healthy (200 OK):** ```json {"status":"healthy"} ``` -**Response (shutting down):** +**Shutting Down (503 Service Unavailable):** ```json {"status":"shutting_down"} ``` ### Readiness Endpoint -Includes in-flight request count: +Returns readiness with in-flight request count: -```go -router.HandleFunc("/ready", srv.ReadinessHandler()) -``` - -**Response:** +**Ready (200 OK):** ```json {"ready":true,"in_flight_requests":12} ``` -**During shutdown:** +**Not Ready (503 Service Unavailable):** ```json {"ready":false,"reason":"shutting_down"} ``` -## Shutdown Callbacks +## Shutdown Behavior -Register cleanup functions to run during shutdown: +When a shutdown signal (SIGINT/SIGTERM) is received: -```go -// Close database -server.RegisterShutdownCallback(func(ctx context.Context) error { - logger.Info("Closing database connection...") - return db.Close() -}) +1. **Mark as shutting down** → New requests get 503 +2. **Execute callbacks** → Run cleanup functions +3. **Drain requests** → Wait up to `DrainTimeout` for in-flight requests +4. **Shutdown servers** → Close listeners and connections -// Flush metrics -server.RegisterShutdownCallback(func(ctx context.Context) error { - logger.Info("Flushing metrics...") - return metricsProvider.Flush(ctx) -}) +``` +Time Event +───────────────────────────────────────── +0s Signal received: SIGTERM + ├─ Mark servers as shutting down + ├─ Reject new requests (503) + └─ Execute shutdown callbacks -// Close cache -server.RegisterShutdownCallback(func(ctx context.Context) error { - logger.Info("Closing cache...") - return cache.Close() -}) +1s Callbacks complete + └─ Start draining requests... 
+ +2s In-flight: 50 requests +3s In-flight: 32 requests +4s In-flight: 12 requests +5s In-flight: 3 requests +6s In-flight: 0 requests ✓ + └─ All requests drained + +6s Shutdown servers +7s All servers stopped ✓ ``` -## Complete Example +## Server Management + +### Get Server Instance ```go -package main - -import ( - "context" - "log" - "net/http" - "time" - - "github.com/bitechdev/ResolveSpec/pkg/middleware" - "github.com/bitechdev/ResolveSpec/pkg/metrics" - "github.com/bitechdev/ResolveSpec/pkg/server" - "github.com/gorilla/mux" -) - -func main() { - // Initialize metrics - metricsProvider := metrics.NewPrometheusProvider() - metrics.SetProvider(metricsProvider) - - // Create router - router := mux.NewRouter() - - // Apply middleware - rateLimiter := middleware.NewRateLimiter(100, 20) - sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size10MB) - sanitizer := middleware.DefaultSanitizer() - - router.Use(rateLimiter.Middleware) - router.Use(sizeLimiter.Middleware) - router.Use(sanitizer.Middleware) - router.Use(metricsProvider.Middleware) - - // API routes - router.HandleFunc("/api/data", dataHandler) - - // Create graceful server - srv := server.NewGracefulServer(server.Config{ - Addr: ":8080", - Handler: router, - ShutdownTimeout: 30 * time.Second, - DrainTimeout: 25 * time.Second, - }) - - // Health checks - router.HandleFunc("/health", srv.HealthCheckHandler()) - router.HandleFunc("/ready", srv.ReadinessHandler()) - - // Metrics endpoint - router.Handle("/metrics", metricsProvider.Handler()) - - // Register shutdown callbacks - server.RegisterShutdownCallback(func(ctx context.Context) error { - log.Println("Cleanup: Flushing metrics...") - return nil - }) - - server.RegisterShutdownCallback(func(ctx context.Context) error { - log.Println("Cleanup: Closing database...") - // return db.Close() - return nil - }) - - // Start server (blocks until shutdown) - log.Printf("Starting server on :8080") - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } - - // Wait for shutdown to complete - srv.Wait() - log.Println("Server stopped") +instance, err := mgr.Get("api-server") +if err != nil { + log.Fatal(err) } -func dataHandler(w http.ResponseWriter, r *http.Request) { - // Your handler logic - time.Sleep(100 * time.Millisecond) // Simulate work - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"message":"success"}`)) +// Check status +fmt.Printf("Address: %s\n", instance.Addr()) +fmt.Printf("Name: %s\n", instance.Name()) +fmt.Printf("In-flight: %d\n", instance.InFlightRequests()) +fmt.Printf("Shutting down: %v\n", instance.IsShuttingDown()) +``` + +### List All Servers + +```go +instances := mgr.List() +for _, instance := range instances { + fmt.Printf("Server: %s at %s\n", instance.Name(), instance.Addr()) +} +``` + +### Remove Server + +```go +// Stop and remove a server +if err := mgr.Remove("api-server"); err != nil { + log.Printf("Error removing server: %v", err) +} +``` + +### Restart All Servers + +```go +// Gracefully restart all servers +if err := mgr.RestartAll(); err != nil { + log.Printf("Error restarting: %v", err) } ``` @@ -250,23 +331,21 @@ spec: ports: - containerPort: 8080 - # Liveness probe - is app running? + # Liveness probe livenessProbe: httpGet: path: /health port: 8080 initialDelaySeconds: 10 periodSeconds: 10 - timeoutSeconds: 5 - # Readiness probe - can app handle traffic? 
+ # Readiness probe readinessProbe: httpGet: path: /ready port: 8080 initialDelaySeconds: 5 periodSeconds: 5 - timeoutSeconds: 3 # Graceful shutdown lifecycle: @@ -274,26 +353,12 @@ spec: exec: command: ["/bin/sh", "-c", "sleep 5"] - # Environment env: - name: SHUTDOWN_TIMEOUT value: "30" -``` -### Service - -```yaml -apiVersion: v1 -kind: Service -metadata: - name: myapp -spec: - selector: - app: myapp - ports: - - port: 80 - targetPort: 8080 - type: LoadBalancer + # Allow time for graceful shutdown + terminationGracePeriodSeconds: 35 ``` ## Docker Compose @@ -312,8 +377,70 @@ services: interval: 10s timeout: 5s retries: 3 - start_period: 10s - stop_grace_period: 35s # Slightly longer than shutdown timeout + stop_grace_period: 35s +``` + +## Complete Example + +```go +package main + +import ( + "context" + "log" + "net/http" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/server" +) + +func main() { + // Create server manager + mgr := server.NewManager() + + // Register shutdown callbacks + mgr.RegisterShutdownCallback(func(ctx context.Context) error { + log.Println("Cleanup: Closing database...") + // return db.Close() + return nil + }) + + // Create router + router := http.NewServeMux() + router.HandleFunc("/api/data", dataHandler) + + // Add server + instance, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "0.0.0.0", + Port: 8080, + Handler: router, + GZIP: true, + ShutdownTimeout: 30 * time.Second, + DrainTimeout: 25 * time.Second, + }) + if err != nil { + log.Fatal(err) + } + + // Add health endpoints + router.HandleFunc("/health", instance.HealthCheckHandler()) + router.HandleFunc("/ready", instance.ReadinessHandler()) + + // Start and wait for shutdown + log.Println("Starting server on :8080") + if err := mgr.ServeWithGracefulShutdown(); err != nil { + log.Printf("Server stopped: %v", err) + } + + log.Println("Server shutdown complete") +} + +func dataHandler(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) // Simulate work + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"message":"success"}`)) +} ``` ## Testing Graceful Shutdown @@ -330,7 +457,7 @@ SERVER_PID=$! # Wait for server to start sleep 2 -# Send some requests +# Send requests for i in {1..10}; do curl http://localhost:8080/api/data & done @@ -341,7 +468,7 @@ sleep 1 # Send shutdown signal kill -TERM $SERVER_PID -# Try to send more requests (should get 503) +# Try more requests (should get 503) curl -v http://localhost:8080/api/data # Wait for server to stop @@ -349,101 +476,13 @@ wait $SERVER_PID echo "Server stopped gracefully" ``` -### Expected Output - -``` -Starting server on :8080 -Received signal: terminated, initiating graceful shutdown -Starting graceful shutdown... -Waiting for 8 in-flight requests to complete... -Waiting for 4 in-flight requests to complete... -Waiting for 1 in-flight requests to complete... -All requests drained in 2.3s -Cleanup: Flushing metrics... -Cleanup: Closing database... -Shutting down HTTP server... 
-Graceful shutdown complete -Server stopped -``` - -## Monitoring In-Flight Requests - -```go -// Get current in-flight count -count := srv.InFlightRequests() -fmt.Printf("In-flight requests: %d\n", count) - -// Check if shutting down -if srv.IsShuttingDown() { - fmt.Println("Server is shutting down") -} -``` - -## Advanced Usage - -### Custom Shutdown Logic - -```go -// Implement custom shutdown -go func() { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - - <-sigChan - log.Println("Shutdown signal received") - - // Custom pre-shutdown logic - log.Println("Running custom cleanup...") - - // Shutdown with callbacks - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := srv.ShutdownWithCallbacks(ctx); err != nil { - log.Printf("Shutdown error: %v", err) - } -}() - -// Start server -srv.server.ListenAndServe() -``` - -### Multiple Servers - -```go -// HTTP server -httpSrv := server.NewGracefulServer(server.Config{ - Addr: ":8080", - Handler: httpRouter, -}) - -// HTTPS server -httpsSrv := server.NewGracefulServer(server.Config{ - Addr: ":8443", - Handler: httpsRouter, -}) - -// Start both -go httpSrv.ListenAndServe() -go httpsSrv.ListenAndServe() - -// Shutdown both on signal -sigChan := make(chan os.Signal, 1) -signal.Notify(sigChan, os.Interrupt) -<-sigChan - -ctx := context.Background() -httpSrv.Shutdown(ctx) -httpsSrv.Shutdown(ctx) -``` - ## Best Practices 1. **Set appropriate timeouts** - `DrainTimeout` < `ShutdownTimeout` - `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds` -2. **Register cleanup callbacks** for: +2. **Use shutdown callbacks** for: - Database connections - Message queues - Metrics flushing @@ -458,7 +497,12 @@ httpsSrv.Shutdown(ctx) - Set `preStop` hook in Kubernetes (5-10s delay) - Allows load balancer to deregister before shutdown -5. **Monitoring** +5. **HTTPS in production** + - Use AutoTLS for public-facing services + - Use certificate files for enterprise PKI + - Use self-signed only for development/testing + +6. **Monitoring** - Track in-flight requests in metrics - Alert on slow drains - Monitor shutdown duration @@ -470,24 +514,63 @@ httpsSrv.Shutdown(ctx) ```go // Increase drain timeout config.DrainTimeout = 60 * time.Second +config.ShutdownTimeout = 65 * time.Second ``` -### Requests Still Timing Out +### Requests Timing Out ```go // Increase write timeout config.WriteTimeout = 30 * time.Second ``` -### Force Shutdown Not Working - -The server will force shutdown after `ShutdownTimeout` even if requests are still in-flight. Adjust timeouts as needed. 
- -### Debugging Shutdown +### Certificate Issues + +```go +// Verify certificate files exist and are readable +if _, err := os.Stat(config.SSLCert); err != nil { + log.Fatalf("Certificate not found: %v", err) +} + +// For AutoTLS, ensure: +// - Port 443 is accessible +// - Domains resolve to server IP +// - Cache directory is writable +``` + +### Debug Logging ```go -// Enable debug logging import "github.com/bitechdev/ResolveSpec/pkg/logger" +// Enable debug logging logger.SetLevel("debug") ``` + +## API Reference + +### Manager Methods + +- `NewManager()` - Create new server manager +- `Add(cfg Config)` - Register server instance +- `Get(name string)` - Get server by name +- `Remove(name string)` - Stop and remove server +- `StartAll()` - Start all registered servers +- `StopAll()` - Stop all servers gracefully +- `StopAllWithContext(ctx)` - Stop with timeout +- `RestartAll()` - Restart all servers +- `List()` - Get all server instances +- `ServeWithGracefulShutdown()` - Start and block until shutdown +- `RegisterShutdownCallback(cb)` - Register cleanup function + +### Instance Methods + +- `Start()` - Start the server +- `Stop(ctx)` - Stop gracefully +- `Addr()` - Get server address +- `Name()` - Get server name +- `HealthCheckHandler()` - Get health handler +- `ReadinessHandler()` - Get readiness handler +- `InFlightRequests()` - Get in-flight count +- `IsShuttingDown()` - Check shutdown status +- `Wait()` - Block until shutdown complete diff --git a/pkg/server/example_test.go b/pkg/server/example_test.go new file mode 100644 index 0000000..032b6f0 --- /dev/null +++ b/pkg/server/example_test.go @@ -0,0 +1,294 @@ +package server_test + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/server" +) + +// ExampleManager_basic demonstrates basic server manager usage +func ExampleManager_basic() { + // Create a server manager + mgr := server.NewManager() + + // Define a simple handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "Hello from server!") + }) + + // Add an HTTP server + _, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: handler, + GZIP: true, // Enable GZIP compression + }) + if err != nil { + panic(err) + } + + // Start all servers + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Server is now running... 
+ // When done, stop gracefully + if err := mgr.StopAll(); err != nil { + panic(err) + } +} + +// ExampleManager_https demonstrates HTTPS configurations +func ExampleManager_https() { + mgr := server.NewManager() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Secure connection!") + }) + + // Option 1: Use certificate files + _, err := mgr.Add(server.Config{ + Name: "https-server-files", + Host: "localhost", + Port: 8443, + Handler: handler, + SSLCert: "/path/to/cert.pem", + SSLKey: "/path/to/key.pem", + }) + if err != nil { + panic(err) + } + + // Option 2: Self-signed certificate (for development) + _, err = mgr.Add(server.Config{ + Name: "https-server-self-signed", + Host: "localhost", + Port: 8444, + Handler: handler, + SelfSignedSSL: true, + }) + if err != nil { + panic(err) + } + + // Option 3: Let's Encrypt / AutoTLS (for production) + _, err = mgr.Add(server.Config{ + Name: "https-server-letsencrypt", + Host: "0.0.0.0", + Port: 443, + Handler: handler, + AutoTLS: true, + AutoTLSDomains: []string{"example.com", "www.example.com"}, + AutoTLSEmail: "admin@example.com", + AutoTLSCacheDir: "./certs-cache", + }) + if err != nil { + panic(err) + } + + // Start all servers + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Cleanup + mgr.StopAll() +} + +// ExampleManager_gracefulShutdown demonstrates graceful shutdown with callbacks +func ExampleManager_gracefulShutdown() { + mgr := server.NewManager() + + // Register shutdown callbacks for cleanup tasks + mgr.RegisterShutdownCallback(func(ctx context.Context) error { + fmt.Println("Closing database connections...") + // Close your database here + return nil + }) + + mgr.RegisterShutdownCallback(func(ctx context.Context) error { + fmt.Println("Flushing metrics...") + // Flush metrics here + return nil + }) + + // Add server with custom timeouts + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate some work + time.Sleep(100 * time.Millisecond) + fmt.Fprintln(w, "Done!") + }) + + _, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: handler, + ShutdownTimeout: 30 * time.Second, // Max time for shutdown + DrainTimeout: 25 * time.Second, // Time to wait for in-flight requests + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + }) + if err != nil { + panic(err) + } + + // Start servers and block until shutdown signal (SIGINT/SIGTERM) + // This will automatically handle graceful shutdown with callbacks + if err := mgr.ServeWithGracefulShutdown(); err != nil { + fmt.Printf("Shutdown completed: %v\n", err) + } +} + +// ExampleManager_healthChecks demonstrates health and readiness endpoints +func ExampleManager_healthChecks() { + mgr := server.NewManager() + + // Create a router with health endpoints + mux := http.NewServeMux() + mux.HandleFunc("/api/data", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Data endpoint") + }) + + // Add server + instance, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: mux, + }) + if err != nil { + panic(err) + } + + // Add health and readiness endpoints + mux.HandleFunc("/health", instance.HealthCheckHandler()) + mux.HandleFunc("/ready", instance.ReadinessHandler()) + + // Start the server + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Health check returns: + // - 200 OK with {"status":"healthy"} when healthy + // - 503 Service 
Unavailable with {"status":"shutting_down"} when shutting down + + // Readiness check returns: + // - 200 OK with {"ready":true,"in_flight_requests":N} when ready + // - 503 Service Unavailable with {"ready":false,"reason":"shutting_down"} when shutting down + + // Cleanup + mgr.StopAll() +} + +// ExampleManager_multipleServers demonstrates running multiple servers +func ExampleManager_multipleServers() { + mgr := server.NewManager() + + // Public API server + publicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Public API") + }) + _, err := mgr.Add(server.Config{ + Name: "public-api", + Host: "0.0.0.0", + Port: 8080, + Handler: publicHandler, + GZIP: true, + }) + if err != nil { + panic(err) + } + + // Admin API server (different port) + adminHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Admin API") + }) + _, err = mgr.Add(server.Config{ + Name: "admin-api", + Host: "localhost", + Port: 8081, + Handler: adminHandler, + }) + if err != nil { + panic(err) + } + + // Metrics server (internal only) + metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Metrics data") + }) + _, err = mgr.Add(server.Config{ + Name: "metrics", + Host: "127.0.0.1", + Port: 9090, + Handler: metricsHandler, + }) + if err != nil { + panic(err) + } + + // Start all servers at once + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Get specific server instance + publicInstance, err := mgr.Get("public-api") + if err != nil { + panic(err) + } + fmt.Printf("Public API running on: %s\n", publicInstance.Addr()) + + // List all servers + instances := mgr.List() + fmt.Printf("Running %d servers\n", len(instances)) + + // Stop all servers gracefully (in parallel) + if err := mgr.StopAll(); err != nil { + panic(err) + } +} + +// ExampleManager_monitoring demonstrates monitoring server state +func ExampleManager_monitoring() { + mgr := server.NewManager() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) // Simulate work + fmt.Fprintln(w, "Done") + }) + + instance, err := mgr.Add(server.Config{ + Name: "api-server", + Host: "localhost", + Port: 8080, + Handler: handler, + }) + if err != nil { + panic(err) + } + + if err := mgr.StartAll(); err != nil { + panic(err) + } + + // Check server status + fmt.Printf("Server address: %s\n", instance.Addr()) + fmt.Printf("Server name: %s\n", instance.Name()) + fmt.Printf("Is shutting down: %v\n", instance.IsShuttingDown()) + fmt.Printf("In-flight requests: %d\n", instance.InFlightRequests()) + + // Cleanup + mgr.StopAll() + + // Wait for complete shutdown + instance.Wait() +} diff --git a/pkg/server/interfaces.go b/pkg/server/interfaces.go new file mode 100644 index 0000000..633c394 --- /dev/null +++ b/pkg/server/interfaces.go @@ -0,0 +1,137 @@ +package server + +import ( + "context" + "net/http" + "time" +) + +// Config holds the configuration for a single web server instance. +type Config struct { + Name string + Host string + Port int + Description string + + // Handler is the http.Handler (e.g., a router) to be served. 
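+	// When the server starts, this handler is wrapped with request tracking,
+	// panic recovery, and (if enabled) GZIP compression.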
+ Handler http.Handler + + // GZIP compression support + GZIP bool + + // TLS/HTTPS configuration options (mutually exclusive) + // Option 1: Provide certificate and key files directly + SSLCert string + SSLKey string + + // Option 2: Use self-signed certificate (for development/testing) + // Generates a self-signed certificate automatically if no SSLCert/SSLKey provided + SelfSignedSSL bool + + // Option 3: Use Let's Encrypt / Certbot for automatic TLS + // AutoTLS enables automatic certificate management via Let's Encrypt + AutoTLS bool + // AutoTLSDomains specifies the domains for Let's Encrypt certificates + AutoTLSDomains []string + // AutoTLSCacheDir specifies where to cache certificates (default: "./certs-cache") + AutoTLSCacheDir string + // AutoTLSEmail is the email for Let's Encrypt registration (optional but recommended) + AutoTLSEmail string + + // Graceful shutdown configuration + // ShutdownTimeout is the maximum time to wait for graceful shutdown + // Default: 30 seconds + ShutdownTimeout time.Duration + + // DrainTimeout is the time to wait for in-flight requests to complete + // before forcing shutdown. Default: 25 seconds + DrainTimeout time.Duration + + // ReadTimeout is the maximum duration for reading the entire request + // Default: 15 seconds + ReadTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out writes of the response + // Default: 15 seconds + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the next request + // Default: 60 seconds + IdleTimeout time.Duration +} + +// Instance defines the interface for a single server instance. +// It abstracts the underlying http.Server, allowing for easier management and testing. +type Instance interface { + // Start begins serving requests. This method should be non-blocking and + // run the server in a separate goroutine. + Start() error + + // Stop gracefully shuts down the server without interrupting any active connections. + // It accepts a context to allow for a timeout. + Stop(ctx context.Context) error + + // Addr returns the network address the server is listening on. + Addr() string + + // Name returns the server instance name. + Name() string + + // HealthCheckHandler returns a handler that responds to health checks. + // Returns 200 OK when healthy, 503 Service Unavailable when shutting down. + HealthCheckHandler() http.HandlerFunc + + // ReadinessHandler returns a handler for readiness checks. + // Includes in-flight request count. + ReadinessHandler() http.HandlerFunc + + // InFlightRequests returns the current number of in-flight requests. + InFlightRequests() int64 + + // IsShuttingDown returns true if the server is shutting down. + IsShuttingDown() bool + + // Wait blocks until shutdown is complete. + Wait() +} + +// Manager defines the interface for a server manager. +// It is responsible for managing the lifecycle of multiple server instances. +type Manager interface { + // Add registers a new server instance based on the provided configuration. + // The server is not started until StartAll or Start is called on the instance. + Add(cfg Config) (Instance, error) + + // Get returns a server instance by its name. + Get(name string) (Instance, error) + + // Remove stops and removes a server instance by its name. + Remove(name string) error + + // StartAll starts all registered server instances that are not already running. + StartAll() error + + // StopAll gracefully shuts down all running server instances. 
+ // Executes shutdown callbacks and drains in-flight requests. + StopAll() error + + // StopAllWithContext gracefully shuts down all running server instances with a context. + StopAllWithContext(ctx context.Context) error + + // RestartAll gracefully restarts all running server instances. + RestartAll() error + + // List returns all registered server instances. + List() []Instance + + // ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received. + // It handles SIGINT and SIGTERM signals and performs graceful shutdown with callbacks. + ServeWithGracefulShutdown() error + + // RegisterShutdownCallback registers a callback to be called during shutdown. + // Useful for cleanup tasks like closing database connections, flushing metrics, etc. + RegisterShutdownCallback(cb ShutdownCallback) +} + +// ShutdownCallback is a function called during graceful shutdown. +type ShutdownCallback func(context.Context) error diff --git a/pkg/server/manager.go b/pkg/server/manager.go new file mode 100644 index 0000000..a211dc3 --- /dev/null +++ b/pkg/server/manager.go @@ -0,0 +1,600 @@ +package server + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/middleware" + "github.com/klauspost/compress/gzhttp" +) + +// gracefulServer wraps http.Server with graceful shutdown capabilities (internal type) +type gracefulServer struct { + server *http.Server + shutdownTimeout time.Duration + drainTimeout time.Duration + inFlightRequests atomic.Int64 + isShuttingDown atomic.Bool + shutdownOnce sync.Once + shutdownComplete chan struct{} +} + +// trackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown +func (gs *gracefulServer) trackRequestsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if shutting down + if gs.isShuttingDown.Load() { + http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable) + return + } + + // Increment in-flight counter + gs.inFlightRequests.Add(1) + defer gs.inFlightRequests.Add(-1) + + // Serve the request + next.ServeHTTP(w, r) + }) +} + +// shutdown performs graceful shutdown with request draining +func (gs *gracefulServer) shutdown(ctx context.Context) error { + var shutdownErr error + + gs.shutdownOnce.Do(func() { + logger.Info("Starting graceful shutdown...") + + // Mark as shutting down (new requests will be rejected) + gs.isShuttingDown.Store(true) + + // Create context with timeout + shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout) + defer cancel() + + // Wait for in-flight requests to complete (with drain timeout) + drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout) + defer drainCancel() + + shutdownErr = gs.drainRequests(drainCtx) + if shutdownErr != nil { + logger.Error("Error draining requests: %v", shutdownErr) + } + + // Shutdown the server + logger.Info("Shutting down HTTP server...") + if err := gs.server.Shutdown(shutdownCtx); err != nil { + logger.Error("Error shutting down server: %v", err) + if shutdownErr == nil { + shutdownErr = err + } + } + + logger.Info("Graceful shutdown complete") + close(gs.shutdownComplete) + }) + + return shutdownErr +} + +// drainRequests waits for in-flight requests to complete +func (gs *gracefulServer) drainRequests(ctx 
context.Context) error { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + startTime := time.Now() + + for { + inFlight := gs.inFlightRequests.Load() + + if inFlight == 0 { + logger.Info("All requests drained in %v", time.Since(startTime)) + return nil + } + + select { + case <-ctx.Done(): + logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight) + return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight) + case <-ticker.C: + logger.Debug("Waiting for %d in-flight requests to complete...", inFlight) + } + } +} + +// inFlightRequests returns the current number of in-flight requests +func (gs *gracefulServer) inFlightRequestsCount() int64 { + return gs.inFlightRequests.Load() +} + +// isShutdown returns true if the server is shutting down +func (gs *gracefulServer) isShutdown() bool { + return gs.isShuttingDown.Load() +} + +// wait blocks until shutdown is complete +func (gs *gracefulServer) wait() { + <-gs.shutdownComplete +} + +// healthCheckHandler returns a handler that responds to health checks +func (gs *gracefulServer) healthCheckHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if gs.isShutdown() { + http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"status":"healthy"}`)) + if err != nil { + logger.Warn("Failed to write health check response: %v", err) + } + } +} + +// readinessHandler returns a handler for readiness checks +func (gs *gracefulServer) readinessHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if gs.isShutdown() { + http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable) + return + } + + inFlight := gs.inFlightRequestsCount() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight) + } +} + +// serverManager manages a collection of server instances with graceful shutdown support. +type serverManager struct { + instances map[string]Instance + mu sync.RWMutex + shutdownCallbacks []ShutdownCallback + callbacksMu sync.Mutex +} + +// NewManager creates a new server manager. +func NewManager() Manager { + return &serverManager{ + instances: make(map[string]Instance), + shutdownCallbacks: make([]ShutdownCallback, 0), + } +} + +// Add registers a new server instance. +func (sm *serverManager) Add(cfg Config) (Instance, error) { + sm.mu.Lock() + defer sm.mu.Unlock() + + if cfg.Name == "" { + return nil, fmt.Errorf("server name cannot be empty") + } + if _, exists := sm.instances[cfg.Name]; exists { + return nil, fmt.Errorf("server with name '%s' already exists", cfg.Name) + } + + instance, err := newInstance(cfg) + if err != nil { + return nil, err + } + + sm.instances[cfg.Name] = instance + return instance, nil +} + +// Get returns a server instance by its name. +func (sm *serverManager) Get(name string) (Instance, error) { + sm.mu.RLock() + defer sm.mu.RUnlock() + + instance, exists := sm.instances[name] + if !exists { + return nil, fmt.Errorf("server with name '%s' not found", name) + } + return instance, nil +} + +// Remove stops and removes a server instance by its name. 
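+// The instance is stopped gracefully first, using its configured shutdown
+// timeout when available and a 10 second default otherwise.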
+func (sm *serverManager) Remove(name string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + instance, exists := sm.instances[name] + if !exists { + return fmt.Errorf("server with name '%s' not found", name) + } + + // Stop the server if it's running. Prefer the server's configured shutdownTimeout + // when available, and fall back to a sensible default. + timeout := 10 * time.Second + if si, ok := instance.(*serverInstance); ok && si.gracefulServer != nil && si.gracefulServer.shutdownTimeout > 0 { + timeout = si.gracefulServer.shutdownTimeout + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + if err := instance.Stop(ctx); err != nil { + logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err) + } + + delete(sm.instances, name) + return nil +} + +// StartAll starts all registered server instances. +func (sm *serverManager) StartAll() error { + sm.mu.RLock() + defer sm.mu.RUnlock() + + var startErrors []error + for name, instance := range sm.instances { + if err := instance.Start(); err != nil { + startErrors = append(startErrors, fmt.Errorf("failed to start server '%s': %w", name, err)) + } + } + + if len(startErrors) > 0 { + return fmt.Errorf("encountered errors while starting servers: %v", startErrors) + } + return nil +} + +// StopAll gracefully shuts down all running server instances. +func (sm *serverManager) StopAll() error { + return sm.StopAllWithContext(context.Background()) +} + +// StopAllWithContext gracefully shuts down all running server instances with a context. +func (sm *serverManager) StopAllWithContext(ctx context.Context) error { + sm.mu.RLock() + instancesToStop := make([]Instance, 0, len(sm.instances)) + for _, instance := range sm.instances { + instancesToStop = append(instancesToStop, instance) + } + sm.mu.RUnlock() + + logger.Info("Shutting down all servers...") + + // Execute shutdown callbacks first + sm.callbacksMu.Lock() + callbacks := make([]ShutdownCallback, len(sm.shutdownCallbacks)) + copy(callbacks, sm.shutdownCallbacks) + sm.callbacksMu.Unlock() + + if len(callbacks) > 0 { + logger.Info("Executing %d shutdown callbacks...", len(callbacks)) + for i, cb := range callbacks { + if err := cb(ctx); err != nil { + logger.Error("Shutdown callback %d failed: %v", i+1, err) + } + } + } + + // Stop all instances in parallel + var shutdownErrors []error + var wg sync.WaitGroup + var errorsMu sync.Mutex + + for _, instance := range instancesToStop { + wg.Add(1) + go func(inst Instance) { + defer wg.Done() + shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + if err := inst.Stop(shutdownCtx); err != nil { + errorsMu.Lock() + shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Name(), err)) + errorsMu.Unlock() + } + }(instance) + } + + wg.Wait() + + if len(shutdownErrors) > 0 { + return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors) + } + logger.Info("All servers stopped gracefully.") + return nil +} + +// RestartAll gracefully restarts all running server instances. +func (sm *serverManager) RestartAll() error { + logger.Info("Restarting all servers...") + if err := sm.StopAll(); err != nil { + return fmt.Errorf("failed to stop servers during restart: %w", err) + } + + // Retry starting all servers with exponential backoff instead of a fixed sleep. 
+ const ( + maxAttempts = 5 + initialBackoff = 100 * time.Millisecond + maxBackoff = 2 * time.Second + ) + + var lastErr error + backoff := initialBackoff + + for attempt := 1; attempt <= maxAttempts; attempt++ { + if err := sm.StartAll(); err != nil { + lastErr = err + if attempt == maxAttempts { + break + } + logger.Warn("Attempt %d to start servers during restart failed: %v; retrying in %s", attempt, err, backoff) + time.Sleep(backoff) + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } + + logger.Info("All servers restarted successfully.") + return nil + } + + return fmt.Errorf("failed to start servers during restart after %d attempts: %w", maxAttempts, lastErr) +} + +// List returns all registered server instances. +func (sm *serverManager) List() []Instance { + sm.mu.RLock() + defer sm.mu.RUnlock() + + instances := make([]Instance, 0, len(sm.instances)) + for _, instance := range sm.instances { + instances = append(instances, instance) + } + return instances +} + +// RegisterShutdownCallback registers a callback to be called during shutdown. +func (sm *serverManager) RegisterShutdownCallback(cb ShutdownCallback) { + sm.callbacksMu.Lock() + defer sm.callbacksMu.Unlock() + sm.shutdownCallbacks = append(sm.shutdownCallbacks, cb) +} + +// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received. +func (sm *serverManager) ServeWithGracefulShutdown() error { + // Start all servers + if err := sm.StartAll(); err != nil { + return fmt.Errorf("failed to start servers: %w", err) + } + + logger.Info("All servers started. Waiting for shutdown signal...") + + // Wait for interrupt signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + + sig := <-sigChan + logger.Info("Received signal: %v, initiating graceful shutdown", sig) + + // Create context with timeout for shutdown + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + return sm.StopAllWithContext(ctx) +} + +// serverInstance is a concrete implementation of the Instance interface. +// It wraps gracefulServer to provide graceful shutdown capabilities. +type serverInstance struct { + cfg Config + gracefulServer *gracefulServer + certFile string // Path to certificate file (may be persistent for self-signed) + keyFile string // Path to key file (may be persistent for self-signed) + mu sync.RWMutex + running bool + serverErr chan error +} + +// newInstance creates a new, unstarted server instance from a config. 
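+// It applies default timeouts and wraps the handler with optional GZIP
+// compression, panic recovery, and the requested TLS configuration.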
+func newInstance(cfg Config) (*serverInstance, error) { + if cfg.Handler == nil { + return nil, fmt.Errorf("handler cannot be nil") + } + + // Set default timeouts + if cfg.ShutdownTimeout == 0 { + cfg.ShutdownTimeout = 30 * time.Second + } + if cfg.DrainTimeout == 0 { + cfg.DrainTimeout = 25 * time.Second + } + if cfg.ReadTimeout == 0 { + cfg.ReadTimeout = 15 * time.Second + } + if cfg.WriteTimeout == 0 { + cfg.WriteTimeout = 15 * time.Second + } + if cfg.IdleTimeout == 0 { + cfg.IdleTimeout = 60 * time.Second + } + + addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + var handler http.Handler = cfg.Handler + + // Wrap with GZIP handler if enabled + if cfg.GZIP { + gz, err := gzhttp.NewWrapper() + if err != nil { + return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err) + } + handler = gz(handler) + } + + // Wrap with the panic recovery middleware + handler = middleware.PanicRecovery(handler) + + // Configure TLS if any TLS option is enabled + tlsConfig, certFile, keyFile, err := configureTLS(cfg) + if err != nil { + return nil, fmt.Errorf("failed to configure TLS: %w", err) + } + + // Create gracefulServer + gracefulSrv := &gracefulServer{ + server: &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + IdleTimeout: cfg.IdleTimeout, + TLSConfig: tlsConfig, + }, + shutdownTimeout: cfg.ShutdownTimeout, + drainTimeout: cfg.DrainTimeout, + shutdownComplete: make(chan struct{}), + } + + return &serverInstance{ + cfg: cfg, + gracefulServer: gracefulSrv, + certFile: certFile, + keyFile: keyFile, + serverErr: make(chan error, 1), + }, nil +} + +// Start begins serving requests in a new goroutine. +func (s *serverInstance) Start() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.running { + return fmt.Errorf("server '%s' is already running", s.cfg.Name) + } + + // Determine if we're using TLS + useTLS := s.cfg.SSLCert != "" || s.cfg.SSLKey != "" || s.cfg.SelfSignedSSL || s.cfg.AutoTLS + + // Wrap handler with request tracking + s.gracefulServer.server.Handler = s.gracefulServer.trackRequestsMiddleware(s.gracefulServer.server.Handler) + + go func() { + defer func() { + s.mu.Lock() + s.running = false + s.mu.Unlock() + logger.Info("Server '%s' stopped.", s.cfg.Name) + }() + + var err error + protocol := "HTTP" + + if useTLS { + protocol = "HTTPS" + logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr()) + + // For AutoTLS, we need to use a TLS listener + if s.cfg.AutoTLS { + // Create listener + ln, lnErr := net.Listen("tcp", s.gracefulServer.server.Addr) + if lnErr != nil { + err = fmt.Errorf("failed to create listener: %w", lnErr) + } else { + // Wrap with TLS + tlsListener := tls.NewListener(ln, s.gracefulServer.server.TLSConfig) + err = s.gracefulServer.server.Serve(tlsListener) + } + } else { + // Use certificate files (regular SSL or self-signed) + err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile) + } + } else { + logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr()) + err = s.gracefulServer.server.ListenAndServe() + } + + // If the server stopped for a reason other than a graceful shutdown, log and report the error. + if err != nil && err != http.ErrServerClosed { + logger.Error("Server '%s' failed: %v", s.cfg.Name, err) + select { + case s.serverErr <- err: + default: + } + } + }() + + s.running = true + // A small delay to allow the goroutine to start and potentially fail on binding. 
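+	// The 50ms window is a heuristic: errors that surface after it are still captured in serverErr
+	// and logged by the serving goroutine rather than returned here.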
+ time.Sleep(50 * time.Millisecond) + + // Check if the server failed to start + select { + case err := <-s.serverErr: + s.running = false + return err + default: + } + + return nil +} + +// Stop gracefully shuts down the server. +func (s *serverInstance) Stop(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.running { + return nil // Already stopped + } + + logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name) + err := s.gracefulServer.shutdown(ctx) + if err == nil { + s.running = false + } + return err +} + +// Addr returns the network address the server is listening on. +func (s *serverInstance) Addr() string { + return s.gracefulServer.server.Addr +} + +// Name returns the server instance name. +func (s *serverInstance) Name() string { + return s.cfg.Name +} + +// HealthCheckHandler returns a handler that responds to health checks. +func (s *serverInstance) HealthCheckHandler() http.HandlerFunc { + return s.gracefulServer.healthCheckHandler() +} + +// ReadinessHandler returns a handler for readiness checks. +func (s *serverInstance) ReadinessHandler() http.HandlerFunc { + return s.gracefulServer.readinessHandler() +} + +// InFlightRequests returns the current number of in-flight requests. +func (s *serverInstance) InFlightRequests() int64 { + return s.gracefulServer.inFlightRequestsCount() +} + +// IsShuttingDown returns true if the server is shutting down. +func (s *serverInstance) IsShuttingDown() bool { + return s.gracefulServer.isShutdown() +} + +// Wait blocks until shutdown is complete. +func (s *serverInstance) Wait() { + s.gracefulServer.wait() +} diff --git a/pkg/server/manager_test.go b/pkg/server/manager_test.go new file mode 100644 index 0000000..e2b35de --- /dev/null +++ b/pkg/server/manager_test.go @@ -0,0 +1,399 @@ +package server + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// getFreePort asks the kernel for a free open port that is ready to use. 
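+// The listener is closed before returning, so another process could in principle grab the port
+// in the interim; acceptable for tests, but not a hard guarantee.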
+func getFreePort(t *testing.T) int { + t.Helper() + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + require.NoError(t, err) + + l, err := net.ListenTCP("tcp", addr) + require.NoError(t, err) + defer l.Close() + return l.Addr().(*net.TCPAddr).Port +} + +func TestServerManagerLifecycle(t *testing.T) { + // Initialize logger for test output + logger.Init(true) + + // Create a new server manager + sm := NewManager() + + // Define a simple test handler + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello, World!")) + }) + + // Get a free port for the server to listen on to avoid conflicts + testPort := getFreePort(t) + + // Add a new server configuration + serverConfig := Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: testHandler, + } + instance, err := sm.Add(serverConfig) + require.NoError(t, err, "should be able to add a new server") + require.NotNil(t, instance, "added instance should not be nil") + + // --- Test StartAll --- + err = sm.StartAll() + require.NoError(t, err, "StartAll should not return an error") + + // Give the server a moment to start up + time.Sleep(100 * time.Millisecond) + + // --- Verify Server is Running --- + client := &http.Client{Timeout: 2 * time.Second} + url := fmt.Sprintf("http://localhost:%d", testPort) + resp, err := client.Get(url) + require.NoError(t, err, "should be able to make a request to the running server") + + assert.Equal(t, http.StatusOK, resp.StatusCode, "expected status OK from the test server") + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, "Hello, World!", string(body), "response body should match expected value") + + // --- Test Get --- + retrievedInstance, err := sm.Get("TestServer") + require.NoError(t, err, "should be able to get server by name") + assert.Equal(t, instance.Addr(), retrievedInstance.Addr(), "retrieved instance should be the same") + + // --- Test List --- + instanceList := sm.List() + require.Len(t, instanceList, 1, "list should contain one instance") + assert.Equal(t, instance.Addr(), instanceList[0].Addr(), "listed instance should be the same") + + // --- Test StopAll --- + err = sm.StopAll() + require.NoError(t, err, "StopAll should not return an error") + + // Give the server a moment to shut down + time.Sleep(100 * time.Millisecond) + + // --- Verify Server is Stopped --- + _, err = client.Get(url) + require.Error(t, err, "should not be able to make a request to a stopped server") + + // --- Test Remove --- + err = sm.Remove("TestServer") + require.NoError(t, err, "should be able to remove a server") + + _, err = sm.Get("TestServer") + require.Error(t, err, "should not be able to get a removed server") +} + +func TestManagerErrorCases(t *testing.T) { + logger.Init(true) + sm := NewManager() + testPort := getFreePort(t) + + // --- Test Add Duplicate Name --- + config1 := Config{Name: "Duplicate", Host: "localhost", Port: testPort, Handler: http.NewServeMux()} + _, err := sm.Add(config1) + require.NoError(t, err) + + config2 := Config{Name: "Duplicate", Host: "localhost", Port: getFreePort(t), Handler: http.NewServeMux()} + _, err = sm.Add(config2) + require.Error(t, err, "should not be able to add a server with a duplicate name") + + // --- Test Get Non-existent --- + _, err = sm.Get("NonExistent") + require.Error(t, err, "should get an error for a non-existent server") + + // --- Test Add with Nil Handler --- + config3 := Config{Name: 
"NilHandler", Host: "localhost", Port: getFreePort(t), Handler: nil} + _, err = sm.Add(config3) + require.Error(t, err, "should not be able to add a server with a nil handler") +} + +func TestGracefulShutdown(t *testing.T) { + logger.Init(true) + sm := NewManager() + + requestsHandled := 0 + var requestsMu sync.Mutex + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestsMu.Lock() + requestsHandled++ + requestsMu.Unlock() + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }) + + testPort := getFreePort(t) + instance, err := sm.Add(Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: handler, + DrainTimeout: 2 * time.Second, + }) + require.NoError(t, err) + + err = sm.StartAll() + require.NoError(t, err) + + // Give server time to start + time.Sleep(100 * time.Millisecond) + + // Send some concurrent requests + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + client := &http.Client{Timeout: 5 * time.Second} + url := fmt.Sprintf("http://localhost:%d", testPort) + resp, err := client.Get(url) + if err == nil { + resp.Body.Close() + } + }() + } + + // Wait a bit for requests to start + time.Sleep(50 * time.Millisecond) + + // Check in-flight requests + inFlight := instance.InFlightRequests() + assert.Greater(t, inFlight, int64(0), "Should have in-flight requests") + + // Stop the server + err = sm.StopAll() + require.NoError(t, err) + + // Wait for all requests to complete + wg.Wait() + + // Verify all requests were handled + requestsMu.Lock() + handled := requestsHandled + requestsMu.Unlock() + assert.GreaterOrEqual(t, handled, 1, "At least some requests should have been handled") + + // Verify no in-flight requests + assert.Equal(t, int64(0), instance.InFlightRequests(), "Should have no in-flight requests after shutdown") +} + +func TestHealthAndReadinessEndpoints(t *testing.T) { + logger.Init(true) + sm := NewManager() + + mux := http.NewServeMux() + testPort := getFreePort(t) + + instance, err := sm.Add(Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: mux, + }) + require.NoError(t, err) + + // Add health and readiness endpoints + mux.HandleFunc("/health", instance.HealthCheckHandler()) + mux.HandleFunc("/ready", instance.ReadinessHandler()) + + err = sm.StartAll() + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + client := &http.Client{Timeout: 2 * time.Second} + baseURL := fmt.Sprintf("http://localhost:%d", testPort) + + // Test health endpoint + resp, err := client.Get(baseURL + "/health") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + assert.Contains(t, string(body), "healthy") + + // Test readiness endpoint + resp, err = client.Get(baseURL + "/ready") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, _ = io.ReadAll(resp.Body) + resp.Body.Close() + assert.Contains(t, string(body), "ready") + assert.Contains(t, string(body), "in_flight_requests") + + // Stop the server + sm.StopAll() +} + +func TestRequestRejectionDuringShutdown(t *testing.T) { + logger.Init(true) + sm := NewManager() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }) + + testPort := getFreePort(t) + _, err := sm.Add(Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: handler, + DrainTimeout: 1 * 
time.Second, + }) + require.NoError(t, err) + + err = sm.StartAll() + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Start shutdown in background + go func() { + time.Sleep(50 * time.Millisecond) + sm.StopAll() + }() + + // Give shutdown time to start + time.Sleep(100 * time.Millisecond) + + // Try to make a request after shutdown started + client := &http.Client{Timeout: 2 * time.Second} + url := fmt.Sprintf("http://localhost:%d", testPort) + resp, err := client.Get(url) + + // The request should either fail (connection refused) or get 503 + if err == nil { + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Should get 503 during shutdown") + resp.Body.Close() + } +} + +func TestShutdownCallbacks(t *testing.T) { + logger.Init(true) + sm := NewManager() + + callbackExecuted := false + var callbackMu sync.Mutex + + sm.RegisterShutdownCallback(func(ctx context.Context) error { + callbackMu.Lock() + callbackExecuted = true + callbackMu.Unlock() + return nil + }) + + testPort := getFreePort(t) + _, err := sm.Add(Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: http.NewServeMux(), + }) + require.NoError(t, err) + + err = sm.StartAll() + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = sm.StopAll() + require.NoError(t, err) + + callbackMu.Lock() + executed := callbackExecuted + callbackMu.Unlock() + + assert.True(t, executed, "Shutdown callback should have been executed") +} + +func TestSelfSignedSSLCertificateReuse(t *testing.T) { + logger.Init(true) + + // Get expected cert directory location + cacheDir, err := os.UserCacheDir() + require.NoError(t, err) + certDir := filepath.Join(cacheDir, "resolvespec", "certs") + + host := "localhost" + certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host)) + keyFile := filepath.Join(certDir, fmt.Sprintf("%s-key.pem", host)) + + // Clean up any existing cert files from previous tests + os.Remove(certFile) + os.Remove(keyFile) + + // First server creation - should generate new certificates + sm1 := NewManager() + testPort1 := getFreePort(t) + _, err = sm1.Add(Config{ + Name: "SSLTestServer1", + Host: host, + Port: testPort1, + Handler: http.NewServeMux(), + SelfSignedSSL: true, + ShutdownTimeout: 5 * time.Second, + }) + require.NoError(t, err) + + // Verify certificates were created + _, err = os.Stat(certFile) + require.NoError(t, err, "certificate file should exist after first creation") + _, err = os.Stat(keyFile) + require.NoError(t, err, "key file should exist after first creation") + + // Get modification time of cert file + info1, err := os.Stat(certFile) + require.NoError(t, err) + modTime1 := info1.ModTime() + + // Wait a bit to ensure different modification times + time.Sleep(100 * time.Millisecond) + + // Second server creation - should reuse existing certificates + sm2 := NewManager() + testPort2 := getFreePort(t) + _, err = sm2.Add(Config{ + Name: "SSLTestServer2", + Host: host, + Port: testPort2, + Handler: http.NewServeMux(), + SelfSignedSSL: true, + ShutdownTimeout: 5 * time.Second, + }) + require.NoError(t, err) + + // Get modification time of cert file after second creation + info2, err := os.Stat(certFile) + require.NoError(t, err) + modTime2 := info2.ModTime() + + // Verify the certificate was reused (same modification time) + assert.Equal(t, modTime1, modTime2, "certificate should be reused, not regenerated") + + // Clean up + sm1.StopAll() + sm2.StopAll() +} diff --git a/pkg/server/shutdown.go b/pkg/server/shutdown.go deleted file 
mode 100644 index 9960405..0000000 --- a/pkg/server/shutdown.go +++ /dev/null @@ -1,296 +0,0 @@ -package server - -import ( - "context" - "fmt" - "net/http" - "os" - "os/signal" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/bitechdev/ResolveSpec/pkg/logger" -) - -// GracefulServer wraps http.Server with graceful shutdown capabilities -type GracefulServer struct { - server *http.Server - shutdownTimeout time.Duration - drainTimeout time.Duration - inFlightRequests atomic.Int64 - isShuttingDown atomic.Bool - shutdownOnce sync.Once - shutdownComplete chan struct{} -} - -// Config holds configuration for the graceful server -type Config struct { - // Addr is the server address (e.g., ":8080") - Addr string - - // Handler is the HTTP handler - Handler http.Handler - - // ShutdownTimeout is the maximum time to wait for graceful shutdown - // Default: 30 seconds - ShutdownTimeout time.Duration - - // DrainTimeout is the time to wait for in-flight requests to complete - // before forcing shutdown. Default: 25 seconds - DrainTimeout time.Duration - - // ReadTimeout is the maximum duration for reading the entire request - ReadTimeout time.Duration - - // WriteTimeout is the maximum duration before timing out writes of the response - WriteTimeout time.Duration - - // IdleTimeout is the maximum amount of time to wait for the next request - IdleTimeout time.Duration -} - -// NewGracefulServer creates a new graceful server -func NewGracefulServer(config Config) *GracefulServer { - if config.ShutdownTimeout == 0 { - config.ShutdownTimeout = 30 * time.Second - } - if config.DrainTimeout == 0 { - config.DrainTimeout = 25 * time.Second - } - if config.ReadTimeout == 0 { - config.ReadTimeout = 10 * time.Second - } - if config.WriteTimeout == 0 { - config.WriteTimeout = 10 * time.Second - } - if config.IdleTimeout == 0 { - config.IdleTimeout = 120 * time.Second - } - - gs := &GracefulServer{ - server: &http.Server{ - Addr: config.Addr, - Handler: config.Handler, - ReadTimeout: config.ReadTimeout, - WriteTimeout: config.WriteTimeout, - IdleTimeout: config.IdleTimeout, - }, - shutdownTimeout: config.ShutdownTimeout, - drainTimeout: config.DrainTimeout, - shutdownComplete: make(chan struct{}), - } - - return gs -} - -// TrackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown -func (gs *GracefulServer) TrackRequestsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if shutting down - if gs.isShuttingDown.Load() { - http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable) - return - } - - // Increment in-flight counter - gs.inFlightRequests.Add(1) - defer gs.inFlightRequests.Add(-1) - - // Serve the request - next.ServeHTTP(w, r) - }) -} - -// ListenAndServe starts the server and handles graceful shutdown -func (gs *GracefulServer) ListenAndServe() error { - // Wrap handler with request tracking - gs.server.Handler = gs.TrackRequestsMiddleware(gs.server.Handler) - - // Start server in goroutine - serverErr := make(chan error, 1) - go func() { - logger.Info("Starting server on %s", gs.server.Addr) - if err := gs.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - serverErr <- err - } - close(serverErr) - }() - - // Wait for interrupt signal - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) - - select { - case err := <-serverErr: - return err - case sig := 
<-sigChan: - logger.Info("Received signal: %v, initiating graceful shutdown", sig) - return gs.Shutdown(context.Background()) - } -} - -// Shutdown performs graceful shutdown with request draining -func (gs *GracefulServer) Shutdown(ctx context.Context) error { - var shutdownErr error - - gs.shutdownOnce.Do(func() { - logger.Info("Starting graceful shutdown...") - - // Mark as shutting down (new requests will be rejected) - gs.isShuttingDown.Store(true) - - // Create context with timeout - shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout) - defer cancel() - - // Wait for in-flight requests to complete (with drain timeout) - drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout) - defer drainCancel() - - shutdownErr = gs.drainRequests(drainCtx) - if shutdownErr != nil { - logger.Error("Error draining requests: %v", shutdownErr) - } - - // Shutdown the server - logger.Info("Shutting down HTTP server...") - if err := gs.server.Shutdown(shutdownCtx); err != nil { - logger.Error("Error shutting down server: %v", err) - if shutdownErr == nil { - shutdownErr = err - } - } - - logger.Info("Graceful shutdown complete") - close(gs.shutdownComplete) - }) - - return shutdownErr -} - -// drainRequests waits for in-flight requests to complete -func (gs *GracefulServer) drainRequests(ctx context.Context) error { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - startTime := time.Now() - - for { - inFlight := gs.inFlightRequests.Load() - - if inFlight == 0 { - logger.Info("All requests drained in %v", time.Since(startTime)) - return nil - } - - select { - case <-ctx.Done(): - logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight) - return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight) - case <-ticker.C: - logger.Debug("Waiting for %d in-flight requests to complete...", inFlight) - } - } -} - -// InFlightRequests returns the current number of in-flight requests -func (gs *GracefulServer) InFlightRequests() int64 { - return gs.inFlightRequests.Load() -} - -// IsShuttingDown returns true if the server is shutting down -func (gs *GracefulServer) IsShuttingDown() bool { - return gs.isShuttingDown.Load() -} - -// Wait blocks until shutdown is complete -func (gs *GracefulServer) Wait() { - <-gs.shutdownComplete -} - -// HealthCheckHandler returns a handler that responds to health checks -// Returns 200 OK when healthy, 503 Service Unavailable when shutting down -func (gs *GracefulServer) HealthCheckHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if gs.IsShuttingDown() { - http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(`{"status":"healthy"}`)) - if err != nil { - logger.Warn("Failed to write. 
%v", err) - } - } -} - -// ReadinessHandler returns a handler for readiness checks -// Includes in-flight request count -func (gs *GracefulServer) ReadinessHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if gs.IsShuttingDown() { - http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable) - return - } - - inFlight := gs.InFlightRequests() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight) - } -} - -// ShutdownCallback is a function called during shutdown -type ShutdownCallback func(context.Context) error - -// shutdownCallbacks stores registered shutdown callbacks -var ( - shutdownCallbacks []ShutdownCallback - shutdownCallbacksMu sync.Mutex -) - -// RegisterShutdownCallback registers a callback to be called during shutdown -// Useful for cleanup tasks like closing database connections, flushing metrics, etc. -func RegisterShutdownCallback(cb ShutdownCallback) { - shutdownCallbacksMu.Lock() - defer shutdownCallbacksMu.Unlock() - shutdownCallbacks = append(shutdownCallbacks, cb) -} - -// executeShutdownCallbacks runs all registered shutdown callbacks -func executeShutdownCallbacks(ctx context.Context) error { - shutdownCallbacksMu.Lock() - callbacks := make([]ShutdownCallback, len(shutdownCallbacks)) - copy(callbacks, shutdownCallbacks) - shutdownCallbacksMu.Unlock() - - var errors []error - for i, cb := range callbacks { - logger.Debug("Executing shutdown callback %d/%d", i+1, len(callbacks)) - if err := cb(ctx); err != nil { - logger.Error("Shutdown callback %d failed: %v", i+1, err) - errors = append(errors, err) - } - } - - if len(errors) > 0 { - return fmt.Errorf("shutdown callbacks failed: %v", errors) - } - - return nil -} - -// ShutdownWithCallbacks performs shutdown and executes all registered callbacks -func (gs *GracefulServer) ShutdownWithCallbacks(ctx context.Context) error { - // Execute callbacks first - if err := executeShutdownCallbacks(ctx); err != nil { - logger.Error("Error executing shutdown callbacks: %v", err) - } - - // Then shutdown the server - return gs.Shutdown(ctx) -} diff --git a/pkg/server/shutdown_test.go b/pkg/server/shutdown_test.go deleted file mode 100644 index 8eef439..0000000 --- a/pkg/server/shutdown_test.go +++ /dev/null @@ -1,231 +0,0 @@ -package server - -import ( - "context" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" -) - -func TestGracefulServerTrackRequests(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(100 * time.Millisecond) - w.WriteHeader(http.StatusOK) - }), - }) - - handler := srv.TrackRequestsMiddleware(srv.server.Handler) - - // Start some requests - var wg sync.WaitGroup - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - defer wg.Done() - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - }() - } - - // Wait a bit for requests to start - time.Sleep(10 * time.Millisecond) - - // Check in-flight count - inFlight := srv.InFlightRequests() - if inFlight == 0 { - t.Error("Should have in-flight requests") - } - - // Wait for all requests to complete - wg.Wait() - - // Check that counter is back to zero - inFlight = srv.InFlightRequests() - if inFlight != 0 { - t.Errorf("In-flight requests should be 0, got %d", inFlight) - } -} - -func TestGracefulServerRejectsRequestsDuringShutdown(t 
*testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }), - }) - - handler := srv.TrackRequestsMiddleware(srv.server.Handler) - - // Mark as shutting down - srv.isShuttingDown.Store(true) - - // Try to make a request - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - // Should get 503 - if w.Code != http.StatusServiceUnavailable { - t.Errorf("Expected 503, got %d", w.Code) - } -} - -func TestHealthCheckHandler(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - }) - - handler := srv.HealthCheckHandler() - - // Healthy - t.Run("Healthy", func(t *testing.T) { - req := httptest.NewRequest("GET", "/health", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Expected 200, got %d", w.Code) - } - - if w.Body.String() != `{"status":"healthy"}` { - t.Errorf("Unexpected body: %s", w.Body.String()) - } - }) - - // Shutting down - t.Run("ShuttingDown", func(t *testing.T) { - srv.isShuttingDown.Store(true) - - req := httptest.NewRequest("GET", "/health", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusServiceUnavailable { - t.Errorf("Expected 503, got %d", w.Code) - } - }) -} - -func TestReadinessHandler(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - }) - - handler := srv.ReadinessHandler() - - // Ready with no in-flight requests - t.Run("Ready", func(t *testing.T) { - req := httptest.NewRequest("GET", "/ready", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Expected 200, got %d", w.Code) - } - - body := w.Body.String() - if body != `{"ready":true,"in_flight_requests":0}` { - t.Errorf("Unexpected body: %s", body) - } - }) - - // Not ready during shutdown - t.Run("NotReady", func(t *testing.T) { - srv.isShuttingDown.Store(true) - - req := httptest.NewRequest("GET", "/ready", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusServiceUnavailable { - t.Errorf("Expected 503, got %d", w.Code) - } - }) -} - -func TestShutdownCallbacks(t *testing.T) { - callbackExecuted := false - - RegisterShutdownCallback(func(ctx context.Context) error { - callbackExecuted = true - return nil - }) - - ctx := context.Background() - err := executeShutdownCallbacks(ctx) - - if err != nil { - t.Errorf("executeShutdownCallbacks() error = %v", err) - } - - if !callbackExecuted { - t.Error("Shutdown callback was not executed") - } - - // Reset for other tests - shutdownCallbacks = nil -} - -func TestDrainRequests(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - DrainTimeout: 1 * time.Second, - }) - - // Simulate in-flight requests - srv.inFlightRequests.Add(3) - - // Start draining in background - go func() { - time.Sleep(100 * time.Millisecond) - // Simulate requests completing - srv.inFlightRequests.Add(-3) - }() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := srv.drainRequests(ctx) - if err != nil { - t.Errorf("drainRequests() error = %v", err) - } - - if srv.InFlightRequests() != 0 { - t.Errorf("In-flight requests 
should be 0, got %d", srv.InFlightRequests()) - } -} - -func TestDrainRequestsTimeout(t *testing.T) { - srv := NewGracefulServer(Config{ - Addr: ":0", - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - DrainTimeout: 100 * time.Millisecond, - }) - - // Simulate in-flight requests that don't complete - srv.inFlightRequests.Add(5) - - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - - err := srv.drainRequests(ctx) - if err == nil { - t.Error("drainRequests() should timeout with error") - } - - // Cleanup - srv.inFlightRequests.Add(-5) -} - -func TestGetClientIP(t *testing.T) { - // This test is in ratelimit_test.go since getClientIP is used by rate limiter - // Including here for completeness of server tests -} diff --git a/pkg/server/tls.go b/pkg/server/tls.go new file mode 100644 index 0000000..a2a308d --- /dev/null +++ b/pkg/server/tls.go @@ -0,0 +1,294 @@ +package server + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "path/filepath" + "sync" + "time" + + "golang.org/x/crypto/acme/autocert" +) + +// certGenerationMutex protects concurrent certificate generation for the same host +var certGenerationMutex sync.Mutex + +// generateSelfSignedCert generates a self-signed certificate for the given host. +// Returns the certificate and private key in PEM format. +func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) { + // Generate private key + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + // Create certificate template + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"ResolveSpec Self-Signed"}, + CommonName: host, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // Add host as DNS name or IP address + if ip := net.ParseIP(host); ip != nil { + template.IPAddresses = []net.IP{ip} + } else { + template.DNSNames = []string{host} + } + + // Create certificate + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + // Encode certificate to PEM + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + // Encode private key to PEM + privBytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) + + return certPEM, keyPEM, nil +} + +// sanitizeHostname converts a hostname to a safe filename by replacing invalid characters. 
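+// For example, "fe80::1" becomes "fe80__1", while "my.host" is left unchanged.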
+func sanitizeHostname(host string) string { + // Replace any character that's not alphanumeric, dot, or dash with underscore + safe := "" + for _, r := range host { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '-' { + safe += string(r) + } else { + safe += "_" + } + } + return safe +} + +// getCertDirectory returns the directory path for storing self-signed certificates. +// Creates the directory if it doesn't exist. +func getCertDirectory() (string, error) { + // Use a consistent directory in the user's cache directory + cacheDir, err := os.UserCacheDir() + if err != nil { + // Fallback to current directory if cache dir is not available + cacheDir = "." + } + + certDir := filepath.Join(cacheDir, "resolvespec", "certs") + + // Create directory if it doesn't exist + if err := os.MkdirAll(certDir, 0700); err != nil { + return "", fmt.Errorf("failed to create certificate directory: %w", err) + } + + return certDir, nil +} + +// isCertificateValid checks if a certificate file exists and is not expired. +func isCertificateValid(certFile string) bool { + // Check if file exists + certData, err := os.ReadFile(certFile) + if err != nil { + return false + } + + // Parse certificate + block, _ := pem.Decode(certData) + if block == nil { + return false + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false + } + + // Check if certificate is expired or will expire in the next 30 days + now := time.Now() + expiryThreshold := now.Add(30 * 24 * time.Hour) + + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + return false + } + + // Renew if expiring soon + if expiryThreshold.After(cert.NotAfter) { + return false + } + + return true +} + +// saveCertToFiles saves certificate and key PEM data to persistent files. +// Returns the file paths for the certificate and key. +func saveCertToFiles(certPEM, keyPEM []byte, host string) (certFile, keyFile string, err error) { + // Get certificate directory + certDir, err := getCertDirectory() + if err != nil { + return "", "", err + } + + // Sanitize hostname for safe file naming + safeHost := sanitizeHostname(host) + + // Use consistent file names based on host + certFile = filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", safeHost)) + keyFile = filepath.Join(certDir, fmt.Sprintf("%s-key.pem", safeHost)) + + // Write certificate + if err := os.WriteFile(certFile, certPEM, 0600); err != nil { + return "", "", fmt.Errorf("failed to write certificate: %w", err) + } + + // Write key + if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil { + return "", "", fmt.Errorf("failed to write private key: %w", err) + } + + return certFile, keyFile, nil +} + +// setupAutoTLS configures automatic TLS certificate management using Let's Encrypt. +// Returns a TLS config that can be used with http.Server. 
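+// Certificates are obtained lazily on the first TLS handshake for each allowed domain and cached
+// on disk in cacheDir, so restarts do not repeat ACME issuance.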
+func setupAutoTLS(domains []string, email, cacheDir string) (*tls.Config, error) { + if len(domains) == 0 { + return nil, fmt.Errorf("at least one domain must be specified for AutoTLS") + } + + // Set default cache directory + if cacheDir == "" { + cacheDir = "./certs-cache" + } + + // Create cache directory if it doesn't exist + if err := os.MkdirAll(cacheDir, 0700); err != nil { + return nil, fmt.Errorf("failed to create certificate cache directory: %w", err) + } + + // Create autocert manager + m := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(cacheDir), + HostPolicy: autocert.HostWhitelist(domains...), + Email: email, + } + + // Create TLS config + tlsConfig := m.TLSConfig() + tlsConfig.MinVersion = tls.VersionTLS13 + + return tlsConfig, nil +} + +// configureTLS configures TLS for the server based on the provided configuration. +// Returns the TLS config and certificate/key file paths (if applicable). +func configureTLS(cfg Config) (*tls.Config, string, string, error) { + // Option 1: Certificate files provided + if cfg.SSLCert != "" && cfg.SSLKey != "" { + // Validate that files exist + if _, err := os.Stat(cfg.SSLCert); os.IsNotExist(err) { + return nil, "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert) + } + if _, err := os.Stat(cfg.SSLKey); os.IsNotExist(err) { + return nil, "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey) + } + + // Return basic TLS config - cert/key will be loaded by ListenAndServeTLS + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + return tlsConfig, cfg.SSLCert, cfg.SSLKey, nil + } + + // Option 2: Auto TLS (Let's Encrypt) + if cfg.AutoTLS { + tlsConfig, err := setupAutoTLS(cfg.AutoTLSDomains, cfg.AutoTLSEmail, cfg.AutoTLSCacheDir) + if err != nil { + return nil, "", "", fmt.Errorf("failed to setup AutoTLS: %w", err) + } + return tlsConfig, "", "", nil + } + + // Option 3: Self-signed certificate + if cfg.SelfSignedSSL { + host := cfg.Host + if host == "" || host == "0.0.0.0" { + host = "localhost" + } + + // Sanitize hostname for safe file naming + safeHost := sanitizeHostname(host) + + // Lock to prevent concurrent certificate generation for the same host + certGenerationMutex.Lock() + defer certGenerationMutex.Unlock() + + // Get certificate directory + certDir, err := getCertDirectory() + if err != nil { + return nil, "", "", fmt.Errorf("failed to get certificate directory: %w", err) + } + + // Check for existing valid certificates + certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", safeHost)) + keyFile := filepath.Join(certDir, fmt.Sprintf("%s-key.pem", safeHost)) + + // If valid certificates exist, reuse them + if isCertificateValid(certFile) { + // Verify key file also exists + if _, err := os.Stat(keyFile); err == nil { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + return tlsConfig, certFile, keyFile, nil + } + } + + // Generate new certificates + certPEM, keyPEM, err := generateSelfSignedCert(host) + if err != nil { + return nil, "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err) + } + + certFile, keyFile, err = saveCertToFiles(certPEM, keyPEM, host) + if err != nil { + return nil, "", "", fmt.Errorf("failed to save self-signed certificate: %w", err) + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + return tlsConfig, certFile, keyFile, nil + } + + return nil, "", "", nil +}
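For reviewers, a minimal usage sketch of the manager API introduced above. This is not part of the diff: the import path github.com/bitechdev/ResolveSpec/pkg/server is inferred from the logger import in manager_test.go, and the handler and callback bodies are placeholders.

package main

import (
	"context"
	"log"
	"net/http"
	"time"

	"github.com/bitechdev/ResolveSpec/pkg/server"
)

func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("ok"))
	})

	sm := server.NewManager()
	inst, err := sm.Add(server.Config{
		Name:            "api",
		Host:            "0.0.0.0",
		Port:            8080,
		Handler:         mux,
		GZIP:            true,
		ShutdownTimeout: 30 * time.Second,
		DrainTimeout:    25 * time.Second,
	})
	if err != nil {
		log.Fatalf("add server: %v", err)
	}

	// Wire up the built-in probes provided by the instance.
	mux.HandleFunc("/health", inst.HealthCheckHandler())
	mux.HandleFunc("/ready", inst.ReadinessHandler())

	// Runs before the servers are drained (close pools, flush metrics, ...).
	sm.RegisterShutdownCallback(func(ctx context.Context) error {
		return nil
	})

	// Starts every registered server and blocks until SIGINT/SIGTERM,
	// then drains in-flight requests within the 30s shutdown window.
	if err := sm.ServeWithGracefulShutdown(); err != nil {
		log.Fatalf("server error: %v", err)
	}
}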