diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index a033b1b..857eecd 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -75,6 +75,25 @@ func CloseErrorTracking() error { return nil } +// extractContext attempts to find a context.Context in the given arguments. +// It returns the found context (or context.Background() if not found) and +// the remaining arguments without the context. +func extractContext(args ...interface{}) (context.Context, []interface{}) { + ctx := context.Background() + var newArgs []interface{} + found := false + + for _, arg := range args { + if c, ok := arg.(context.Context); ok && !found { + ctx = c + found = true + } else { + newArgs = append(newArgs, arg) + } + } + return ctx, newArgs +} + func Info(template string, args ...interface{}) { if Logger == nil { log.Printf(template, args...) @@ -84,7 +103,8 @@ func Info(template string, args ...interface{}) { } func Warn(template string, args ...interface{}) { - message := fmt.Sprintf(template, args...) + ctx, remainingArgs := extractContext(args...) + message := fmt.Sprintf(template, remainingArgs...) if Logger == nil { log.Printf("%s", message) } else { @@ -93,14 +113,15 @@ func Warn(template string, args ...interface{}) { // Send to error tracker if errorTracker != nil { - errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{ + errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{ "process_id": os.Getpid(), }) } } func Error(template string, args ...interface{}) { - message := fmt.Sprintf(template, args...) + ctx, remainingArgs := extractContext(args...) + message := fmt.Sprintf(template, remainingArgs...) if Logger == nil { log.Printf("%s", message) } else { @@ -109,7 +130,7 @@ func Error(template string, args ...interface{}) { // Send to error tracker if errorTracker != nil { - errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{ + errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{ "process_id": os.Getpid(), }) } @@ -124,12 +145,13 @@ func Debug(template string, args ...interface{}) { } // CatchPanic - Handle panic -func CatchPanicCallback(location string, cb func(err any)) { +func CatchPanicCallback(location string, cb func(err any), args ...interface{}) { + ctx, _ := extractContext(args...) if err := recover(); err != nil { callstack := debug.Stack() if Logger != nil { - Error("Panic in %s : %v", location, err) + Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly } else { fmt.Printf("%s:PANIC->%+v", location, err) debug.PrintStack() @@ -137,7 +159,7 @@ func CatchPanicCallback(location string, cb func(err any)) { // Send to error tracker if errorTracker != nil { - errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{ + errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{ "location": location, "process_id": os.Getpid(), }) @@ -150,8 +172,8 @@ func CatchPanicCallback(location string, cb func(err any)) { } // CatchPanic - Handle panic -func CatchPanic(location string) { - CatchPanicCallback(location, nil) +func CatchPanic(location string, args ...interface{}) { + CatchPanicCallback(location, nil, args...) } // HandlePanic logs a panic and returns it as an error @@ -163,13 +185,14 @@ func CatchPanic(location string) { // err = logger.HandlePanic("MethodName", r) // } // }() -func HandlePanic(methodName string, r any) error { +func HandlePanic(methodName string, r any, args ...interface{}) error { + ctx, _ := extractContext(args...) stack := debug.Stack() - Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack)) + Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly // Send to error tracker if errorTracker != nil { - errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{ + errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{ "method": methodName, "process_id": os.Getpid(), }) diff --git a/pkg/metrics/interfaces.go b/pkg/metrics/interfaces.go index 9d11586..8b35463 100644 --- a/pkg/metrics/interfaces.go +++ b/pkg/metrics/interfaces.go @@ -39,6 +39,9 @@ type Provider interface { // UpdateEventQueueSize updates the event queue size metric UpdateEventQueueSize(size int64) + // RecordPanic records a panic event + RecordPanic(methodName string) + // Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint) Handler() http.Handler } @@ -75,6 +78,7 @@ func (n *NoOpProvider) RecordEventPublished(source, eventType string) {} func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) { } func (n *NoOpProvider) UpdateEventQueueSize(size int64) {} +func (n *NoOpProvider) RecordPanic(methodName string) {} func (n *NoOpProvider) Handler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) diff --git a/pkg/metrics/prometheus.go b/pkg/metrics/prometheus.go index 49c6ebb..9761a68 100644 --- a/pkg/metrics/prometheus.go +++ b/pkg/metrics/prometheus.go @@ -20,6 +20,7 @@ type PrometheusProvider struct { cacheHits *prometheus.CounterVec cacheMisses *prometheus.CounterVec cacheSize *prometheus.GaugeVec + panicsTotal *prometheus.CounterVec } // NewPrometheusProvider creates a new Prometheus metrics provider @@ -83,6 +84,13 @@ func NewPrometheusProvider() *PrometheusProvider { }, []string{"provider"}, ), + panicsTotal: promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "panics_total", + Help: "Total number of panics", + }, + []string{"method"}, + ), } } @@ -145,6 +153,11 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) { p.cacheSize.WithLabelValues(provider).Set(float64(size)) } +// RecordPanic implements the Provider interface +func (p *PrometheusProvider) RecordPanic(methodName string) { + p.panicsTotal.WithLabelValues(methodName).Inc() +} + // Handler implements Provider interface func (p *PrometheusProvider) Handler() http.Handler { return promhttp.Handler() diff --git a/pkg/middleware/panic.go b/pkg/middleware/panic.go new file mode 100644 index 0000000..fa36086 --- /dev/null +++ b/pkg/middleware/panic.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "net/http" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/metrics" +) + +const panicMiddlewareMethodName = "PanicMiddleware" + +// PanicRecovery is a middleware that recovers from panics, logs the error, +// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error. +func PanicRecovery(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rcv := recover(); rcv != nil { + // Record the panic metric + metrics.GetProvider().RecordPanic(panicMiddlewareMethodName) + + // Log the panic and send to error tracker + // We pass the request context so the error tracker can potentially + // link the panic to the request trace. + ctx := r.Context() + err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx) + + // Respond with a 500 error + http.Error(w, err.Error(), http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/middleware/panic_test.go b/pkg/middleware/panic_test.go new file mode 100644 index 0000000..e5cb77e --- /dev/null +++ b/pkg/middleware/panic_test.go @@ -0,0 +1,86 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/metrics" + "github.com/stretchr/testify/assert" +) + +// mockMetricsProvider is a mock for the metrics provider to check if methods are called. +type mockMetricsProvider struct { + metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods + panicRecorded bool + methodName string +} + +func (m *mockMetricsProvider) RecordPanic(methodName string) { + m.panicRecorded = true + m.methodName = methodName +} + +func TestPanicRecovery(t *testing.T) { + // Initialize a mock logger to avoid actual logging output during tests + logger.Init(true) + + // Setup mock metrics provider + mockProvider := &mockMetricsProvider{} + originalProvider := metrics.GetProvider() + metrics.SetProvider(mockProvider) + defer metrics.SetProvider(originalProvider) // Restore original provider after test + + // 1. Test case: A handler that panics + t.Run("recovers from panic and returns 500", func(t *testing.T) { + // Reset mock state for this sub-test + mockProvider.panicRecorded = false + mockProvider.methodName = "" + + panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("something went terribly wrong") + }) + + // Create the middleware wrapping the panicking handler + testHandler := PanicRecovery(panicHandler) + + // Create a test request and response recorder + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + // Serve the request + testHandler.ServeHTTP(rr, req) + + // Assertions + assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500") + assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body") + + // Assert that the metric was recorded + assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider") + assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name") + }) + + // 2. Test case: A handler that does NOT panic + t.Run("does not interfere with a non-panicking handler", func(t *testing.T) { + // Reset mock state for this sub-test + mockProvider.panicRecorded = false + + successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + testHandler := PanicRecovery(successHandler) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + testHandler.ServeHTTP(rr, req) + + // Assertions + assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200") + assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body") + assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic") + }) +} diff --git a/pkg/server/interfaces.go b/pkg/server/interfaces.go new file mode 100644 index 0000000..892cb58 --- /dev/null +++ b/pkg/server/interfaces.go @@ -0,0 +1,52 @@ +package server + +import ( + "context" + "net/http" +) + +// Config holds the configuration for a single web server instance. +type Config struct { + Name string + Host string + Port int + Description string + SSLCert string + SSLKey string + GZIP bool + // Handler is the http.Handler (e.g., a router) to be served. + Handler http.Handler +} + +// Instance defines the interface for a single server instance. +// It abstracts the underlying http.Server, allowing for easier management and testing. +type Instance interface { + // Start begins serving requests. This method should be non-blocking and + // run the server in a separate goroutine. + Start() error + // Stop gracefully shuts down the server without interrupting any active connections. + // It accepts a context to allow for a timeout. + Stop(ctx context.Context) error + // Addr returns the network address the server is listening on. + Addr() string +} + +// Manager defines the interface for a server manager. +// It is responsible for managing the lifecycle of multiple server instances. +type Manager interface { + // Add registers a new server instance based on the provided configuration. + // The server is not started until StartAll or Start is called on the instance. + Add(cfg Config) (Instance, error) + // Get returns a server instance by its name. + Get(name string) (Instance, error) + // Remove stops and removes a server instance by its name. + Remove(name string) error + // StartAll starts all registered server instances that are not already running. + StartAll() error + // StopAll gracefully shuts down all running server instances. + StopAll() error + // RestartAll gracefully restarts all running server instances. + RestartAll() error + // List returns all registered server instances. + List() []Instance +} diff --git a/pkg/server/manager.go b/pkg/server/manager.go new file mode 100644 index 0000000..005d8fa --- /dev/null +++ b/pkg/server/manager.go @@ -0,0 +1,282 @@ +package server + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "sync" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/middleware" + "github.com/klauspost/compress/gzhttp" + "golang.org/x/net/http2" +) + +// serverManager manages a collection of server instances. +type serverManager struct { + instances map[string]Instance + mu sync.RWMutex +} + +// NewManager creates a new server manager. +func NewManager() Manager { + return &serverManager{ + instances: make(map[string]Instance), + } +} + +// Add registers a new server instance. +func (sm *serverManager) Add(cfg Config) (Instance, error) { + sm.mu.Lock() + defer sm.mu.Unlock() + + if cfg.Name == "" { + return nil, fmt.Errorf("server name cannot be empty") + } + if _, exists := sm.instances[cfg.Name]; exists { + return nil, fmt.Errorf("server with name '%s' already exists", cfg.Name) + } + + instance, err := newInstance(cfg) + if err != nil { + return nil, err + } + + sm.instances[cfg.Name] = instance + return instance, nil +} + +// Get returns a server instance by its name. +func (sm *serverManager) Get(name string) (Instance, error) { + sm.mu.RLock() + defer sm.mu.RUnlock() + + instance, exists := sm.instances[name] + if !exists { + return nil, fmt.Errorf("server with name '%s' not found", name) + } + return instance, nil +} + +// Remove stops and removes a server instance by its name. +func (sm *serverManager) Remove(name string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + instance, exists := sm.instances[name] + if !exists { + return fmt.Errorf("server with name '%s' not found", name) + } + + // Stop the server if it's running + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := instance.Stop(ctx); err != nil { + logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err, context.Background()) + } + + delete(sm.instances, name) + return nil +} + +// StartAll starts all registered server instances. +func (sm *serverManager) StartAll() error { + sm.mu.RLock() + defer sm.mu.RUnlock() + + var startErrors []error + for name, instance := range sm.instances { + if err := instance.Start(); err != nil { + startErrors = append(startErrors, fmt.Errorf("failed to start server '%s': %w", name, err)) + } + } + + if len(startErrors) > 0 { + // In a real-world scenario, you might want a more sophisticated error handling strategy + return fmt.Errorf("encountered errors while starting servers: %v", startErrors) + } + return nil +} + +// StopAll gracefully shuts down all running server instances. +func (sm *serverManager) StopAll() error { + sm.mu.RLock() + instancesToStop := make([]Instance, 0, len(sm.instances)) + for _, instance := range sm.instances { + instancesToStop = append(instancesToStop, instance) + } + sm.mu.RUnlock() + + logger.Info("Shutting down all servers...", context.Background()) + + var shutdownErrors []error + var wg sync.WaitGroup + + for _, instance := range instancesToStop { + wg.Add(1) + go func(inst Instance) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + if err := inst.Stop(ctx); err != nil { + shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Addr(), err)) + } + }(instance) + } + + wg.Wait() + + if len(shutdownErrors) > 0 { + return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors) + } + logger.Info("All servers stopped gracefully.", context.Background()) + return nil +} + +// RestartAll gracefully restarts all running server instances. +func (sm *serverManager) RestartAll() error { + logger.Info("Restarting all servers...", context.Background()) + if err := sm.StopAll(); err != nil { + return fmt.Errorf("failed to stop servers during restart: %w", err) + } + + // Give ports time to be released + time.Sleep(200 * time.Millisecond) + + if err := sm.StartAll(); err != nil { + return fmt.Errorf("failed to start servers during restart: %w", err) + } + logger.Info("All servers restarted successfully.", context.Background()) + return nil +} + +// List returns all registered server instances. +func (sm *serverManager) List() []Instance { + sm.mu.RLock() + defer sm.mu.RUnlock() + + instances := make([]Instance, 0, len(sm.instances)) + for _, instance := range sm.instances { + instances = append(instances, instance) + } + return instances +} + +// serverInstance is a concrete implementation of the Instance interface. +type serverInstance struct { + cfg Config + httpServer *http.Server + mu sync.RWMutex + running bool + stopCh chan struct{} +} + +// newInstance creates a new, unstarted server instance from a config. +func newInstance(cfg Config) (*serverInstance, error) { + if cfg.Handler == nil { + return nil, fmt.Errorf("handler cannot be nil") + } + + addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + var handler http.Handler = cfg.Handler + + // Wrap with GZIP handler if enabled + if cfg.GZIP { + gz, err := gzhttp.NewWrapper(gzhttp.BestSpeed) + if err != nil { + return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err) + } + handler = gz(handler) + } + + // Wrap with the panic recovery middleware + handler = middleware.PanicRecovery(handler) + + // Here you could add other default middleware like request logging, metrics, etc. + + httpServer := &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, + } + + return &serverInstance{ + cfg: cfg, + httpServer: httpServer, + stopCh: make(chan struct{}), + }, nil +} + +// Start begins serving requests in a new goroutine. +func (s *serverInstance) Start() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.running { + return fmt.Errorf("server '%s' is already running", s.cfg.Name) + } + + hasSSL := s.cfg.SSLCert != "" && s.cfg.SSLKey != "" + + go func() { + defer func() { + s.mu.Lock() + s.running = false + s.mu.Unlock() + logger.Info("Server '%s' stopped.", s.cfg.Name, context.Background()) + }() + + var err error + protocol := "HTTP" + + if hasSSL { + protocol = "HTTPS" + // Configure TLS + HTTP/2 + s.httpServer.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr(), context.Background()) err = s.httpServer.ListenAndServeTLS(s.cfg.SSLCert, s.cfg.SSLKey) + } else { + logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr(), context.Background()) + err = s.httpServer.ListenAndServe() + } + + // If the server stopped for a reason other than a graceful shutdown, log the error. + if err != nil && err != http.ErrServerClosed { + logger.Error("Server '%s' failed: %v", s.cfg.Name, err, context.Background()) + } + }() + + s.running = true + // A small delay to allow the goroutine to start and potentially fail on binding. + // A more robust solution might involve a channel signal. + time.Sleep(50 * time.Millisecond) + + return nil +} + +// Stop gracefully shuts down the server. +func (s *serverInstance) Stop(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.running { + return nil // Already stopped + } + + logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name) + err := s.httpServer.Shutdown(ctx) + if err == nil { + s.running = false + } + return err +} + +// Addr returns the network address the server is listening on. +func (s *serverInstance) Addr() string { + return s.httpServer.Addr +} diff --git a/pkg/server/manager_test.go b/pkg/server/manager_test.go new file mode 100644 index 0000000..1ff08f7 --- /dev/null +++ b/pkg/server/manager_test.go @@ -0,0 +1,125 @@ +package server + +import ( + "fmt" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// getFreePort asks the kernel for a free open port that is ready to use. +func getFreePort(t *testing.T) int { + t.Helper() + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + require.NoError(t, err) + + l, err := net.ListenTCP("tcp", addr) + require.NoError(t, err) + defer l.Close() + return l.Addr().(*net.TCPAddr).Port +} + +func TestServerManagerLifecycle(t *testing.T) { + // Initialize logger for test output + logger.Init(true) + + // Create a new server manager + sm := NewManager() + + // Define a simple test handler + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello, World!")) + }) + + // Get a free port for the server to listen on to avoid conflicts + testPort := getFreePort(t) + + // Add a new server configuration + serverConfig := Config{ + Name: "TestServer", + Host: "localhost", + Port: testPort, + Handler: testHandler, + } + instance, err := sm.Add(serverConfig) + require.NoError(t, err, "should be able to add a new server") + require.NotNil(t, instance, "added instance should not be nil") + + // --- Test StartAll --- + err = sm.StartAll() + require.NoError(t, err, "StartAll should not return an error") + + // Give the server a moment to start up + time.Sleep(100 * time.Millisecond) + + // --- Verify Server is Running --- + client := &http.Client{Timeout: 2 * time.Second} + url := fmt.Sprintf("http://localhost:%d", testPort) + resp, err := client.Get(url) + require.NoError(t, err, "should be able to make a request to the running server") + + assert.Equal(t, http.StatusOK, resp.StatusCode, "expected status OK from the test server") + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, "Hello, World!", string(body), "response body should match expected value") + + // --- Test Get --- + retrievedInstance, err := sm.Get("TestServer") + require.NoError(t, err, "should be able to get server by name") + assert.Equal(t, instance.Addr(), retrievedInstance.Addr(), "retrieved instance should be the same") + + // --- Test List --- + instanceList := sm.List() + require.Len(t, instanceList, 1, "list should contain one instance") + assert.Equal(t, instance.Addr(), instanceList[0].Addr(), "listed instance should be the same") + + // --- Test StopAll --- + err = sm.StopAll() + require.NoError(t, err, "StopAll should not return an error") + + // Give the server a moment to shut down + time.Sleep(100 * time.Millisecond) + + // --- Verify Server is Stopped --- + _, err = client.Get(url) + require.Error(t, err, "should not be able to make a request to a stopped server") + + // --- Test Remove --- + err = sm.Remove("TestServer") + require.NoError(t, err, "should be able to remove a server") + + _, err = sm.Get("TestServer") + require.Error(t, err, "should not be able to get a removed server") +} + +func TestManagerErrorCases(t *testing.T) { + logger.Init(true) + sm := NewManager() + testPort := getFreePort(t) + + // --- Test Add Duplicate Name --- + config1 := Config{Name: "Duplicate", Host: "localhost", Port: testPort, Handler: http.NewServeMux()} + _, err := sm.Add(config1) + require.NoError(t, err) + + config2 := Config{Name: "Duplicate", Host: "localhost", Port: getFreePort(t), Handler: http.NewServeMux()} + _, err = sm.Add(config2) + require.Error(t, err, "should not be able to add a server with a duplicate name") + + // --- Test Get Non-existent --- + _, err = sm.Get("NonExistent") + require.Error(t, err, "should get an error for a non-existent server") + + // --- Test Add with Nil Handler --- + config3 := Config{Name: "NilHandler", Host: "localhost", Port: getFreePort(t), Handler: nil} + _, err = sm.Add(config3) + require.Error(t, err, "should not be able to add a server with a nil handler") +}