mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-09 13:04:24 +00:00
Work on server
This commit is contained in:
@@ -75,6 +75,25 @@ func CloseErrorTracking() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractContext attempts to find a context.Context in the given arguments.
|
||||||
|
// It returns the found context (or context.Background() if not found) and
|
||||||
|
// the remaining arguments without the context.
|
||||||
|
func extractContext(args ...interface{}) (context.Context, []interface{}) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var newArgs []interface{}
|
||||||
|
found := false
|
||||||
|
|
||||||
|
for _, arg := range args {
|
||||||
|
if c, ok := arg.(context.Context); ok && !found {
|
||||||
|
ctx = c
|
||||||
|
found = true
|
||||||
|
} else {
|
||||||
|
newArgs = append(newArgs, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, newArgs
|
||||||
|
}
|
||||||
|
|
||||||
func Info(template string, args ...interface{}) {
|
func Info(template string, args ...interface{}) {
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf(template, args...)
|
log.Printf(template, args...)
|
||||||
@@ -84,7 +103,8 @@ func Info(template string, args ...interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Warn(template string, args ...interface{}) {
|
func Warn(template string, args ...interface{}) {
|
||||||
message := fmt.Sprintf(template, args...)
|
ctx, remainingArgs := extractContext(args...)
|
||||||
|
message := fmt.Sprintf(template, remainingArgs...)
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf("%s", message)
|
log.Printf("%s", message)
|
||||||
} else {
|
} else {
|
||||||
@@ -93,14 +113,15 @@ func Warn(template string, args ...interface{}) {
|
|||||||
|
|
||||||
// Send to error tracker
|
// Send to error tracker
|
||||||
if errorTracker != nil {
|
if errorTracker != nil {
|
||||||
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
|
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{
|
||||||
"process_id": os.Getpid(),
|
"process_id": os.Getpid(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Error(template string, args ...interface{}) {
|
func Error(template string, args ...interface{}) {
|
||||||
message := fmt.Sprintf(template, args...)
|
ctx, remainingArgs := extractContext(args...)
|
||||||
|
message := fmt.Sprintf(template, remainingArgs...)
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf("%s", message)
|
log.Printf("%s", message)
|
||||||
} else {
|
} else {
|
||||||
@@ -109,7 +130,7 @@ func Error(template string, args ...interface{}) {
|
|||||||
|
|
||||||
// Send to error tracker
|
// Send to error tracker
|
||||||
if errorTracker != nil {
|
if errorTracker != nil {
|
||||||
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
|
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{
|
||||||
"process_id": os.Getpid(),
|
"process_id": os.Getpid(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -124,12 +145,13 @@ func Debug(template string, args ...interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CatchPanic - Handle panic
|
// CatchPanic - Handle panic
|
||||||
func CatchPanicCallback(location string, cb func(err any)) {
|
func CatchPanicCallback(location string, cb func(err any), args ...interface{}) {
|
||||||
|
ctx, _ := extractContext(args...)
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
callstack := debug.Stack()
|
callstack := debug.Stack()
|
||||||
|
|
||||||
if Logger != nil {
|
if Logger != nil {
|
||||||
Error("Panic in %s : %v", location, err)
|
Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||||
debug.PrintStack()
|
debug.PrintStack()
|
||||||
@@ -137,7 +159,7 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
|||||||
|
|
||||||
// Send to error tracker
|
// Send to error tracker
|
||||||
if errorTracker != nil {
|
if errorTracker != nil {
|
||||||
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
|
errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{
|
||||||
"location": location,
|
"location": location,
|
||||||
"process_id": os.Getpid(),
|
"process_id": os.Getpid(),
|
||||||
})
|
})
|
||||||
@@ -150,8 +172,8 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CatchPanic - Handle panic
|
// CatchPanic - Handle panic
|
||||||
func CatchPanic(location string) {
|
func CatchPanic(location string, args ...interface{}) {
|
||||||
CatchPanicCallback(location, nil)
|
CatchPanicCallback(location, nil, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlePanic logs a panic and returns it as an error
|
// HandlePanic logs a panic and returns it as an error
|
||||||
@@ -163,13 +185,14 @@ func CatchPanic(location string) {
|
|||||||
// err = logger.HandlePanic("MethodName", r)
|
// err = logger.HandlePanic("MethodName", r)
|
||||||
// }
|
// }
|
||||||
// }()
|
// }()
|
||||||
func HandlePanic(methodName string, r any) error {
|
func HandlePanic(methodName string, r any, args ...interface{}) error {
|
||||||
|
ctx, _ := extractContext(args...)
|
||||||
stack := debug.Stack()
|
stack := debug.Stack()
|
||||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly
|
||||||
|
|
||||||
// Send to error tracker
|
// Send to error tracker
|
||||||
if errorTracker != nil {
|
if errorTracker != nil {
|
||||||
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
|
errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{
|
||||||
"method": methodName,
|
"method": methodName,
|
||||||
"process_id": os.Getpid(),
|
"process_id": os.Getpid(),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -39,6 +39,9 @@ type Provider interface {
|
|||||||
// UpdateEventQueueSize updates the event queue size metric
|
// UpdateEventQueueSize updates the event queue size metric
|
||||||
UpdateEventQueueSize(size int64)
|
UpdateEventQueueSize(size int64)
|
||||||
|
|
||||||
|
// RecordPanic records a panic event
|
||||||
|
RecordPanic(methodName string)
|
||||||
|
|
||||||
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
||||||
Handler() http.Handler
|
Handler() http.Handler
|
||||||
}
|
}
|
||||||
@@ -75,6 +78,7 @@ func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
|
|||||||
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||||
}
|
}
|
||||||
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
|
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
|
||||||
|
func (n *NoOpProvider) RecordPanic(methodName string) {}
|
||||||
func (n *NoOpProvider) Handler() http.Handler {
|
func (n *NoOpProvider) Handler() http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type PrometheusProvider struct {
|
|||||||
cacheHits *prometheus.CounterVec
|
cacheHits *prometheus.CounterVec
|
||||||
cacheMisses *prometheus.CounterVec
|
cacheMisses *prometheus.CounterVec
|
||||||
cacheSize *prometheus.GaugeVec
|
cacheSize *prometheus.GaugeVec
|
||||||
|
panicsTotal *prometheus.CounterVec
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPrometheusProvider creates a new Prometheus metrics provider
|
// NewPrometheusProvider creates a new Prometheus metrics provider
|
||||||
@@ -83,6 +84,13 @@ func NewPrometheusProvider() *PrometheusProvider {
|
|||||||
},
|
},
|
||||||
[]string{"provider"},
|
[]string{"provider"},
|
||||||
),
|
),
|
||||||
|
panicsTotal: promauto.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "panics_total",
|
||||||
|
Help: "Total number of panics",
|
||||||
|
},
|
||||||
|
[]string{"method"},
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,6 +153,11 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
|
|||||||
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordPanic implements the Provider interface
|
||||||
|
func (p *PrometheusProvider) RecordPanic(methodName string) {
|
||||||
|
p.panicsTotal.WithLabelValues(methodName).Inc()
|
||||||
|
}
|
||||||
|
|
||||||
// Handler implements Provider interface
|
// Handler implements Provider interface
|
||||||
func (p *PrometheusProvider) Handler() http.Handler {
|
func (p *PrometheusProvider) Handler() http.Handler {
|
||||||
return promhttp.Handler()
|
return promhttp.Handler()
|
||||||
|
|||||||
33
pkg/middleware/panic.go
Normal file
33
pkg/middleware/panic.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
const panicMiddlewareMethodName = "PanicMiddleware"
|
||||||
|
|
||||||
|
// PanicRecovery is a middleware that recovers from panics, logs the error,
|
||||||
|
// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error.
|
||||||
|
func PanicRecovery(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
if rcv := recover(); rcv != nil {
|
||||||
|
// Record the panic metric
|
||||||
|
metrics.GetProvider().RecordPanic(panicMiddlewareMethodName)
|
||||||
|
|
||||||
|
// Log the panic and send to error tracker
|
||||||
|
// We pass the request context so the error tracker can potentially
|
||||||
|
// link the panic to the request trace.
|
||||||
|
ctx := r.Context()
|
||||||
|
err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx)
|
||||||
|
|
||||||
|
// Respond with a 500 error
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
86
pkg/middleware/panic_test.go
Normal file
86
pkg/middleware/panic_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockMetricsProvider is a mock for the metrics provider to check if methods are called.
|
||||||
|
type mockMetricsProvider struct {
|
||||||
|
metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods
|
||||||
|
panicRecorded bool
|
||||||
|
methodName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMetricsProvider) RecordPanic(methodName string) {
|
||||||
|
m.panicRecorded = true
|
||||||
|
m.methodName = methodName
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPanicRecovery(t *testing.T) {
|
||||||
|
// Initialize a mock logger to avoid actual logging output during tests
|
||||||
|
logger.Init(true)
|
||||||
|
|
||||||
|
// Setup mock metrics provider
|
||||||
|
mockProvider := &mockMetricsProvider{}
|
||||||
|
originalProvider := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(mockProvider)
|
||||||
|
defer metrics.SetProvider(originalProvider) // Restore original provider after test
|
||||||
|
|
||||||
|
// 1. Test case: A handler that panics
|
||||||
|
t.Run("recovers from panic and returns 500", func(t *testing.T) {
|
||||||
|
// Reset mock state for this sub-test
|
||||||
|
mockProvider.panicRecorded = false
|
||||||
|
mockProvider.methodName = ""
|
||||||
|
|
||||||
|
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("something went terribly wrong")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create the middleware wrapping the panicking handler
|
||||||
|
testHandler := PanicRecovery(panicHandler)
|
||||||
|
|
||||||
|
// Create a test request and response recorder
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Serve the request
|
||||||
|
testHandler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500")
|
||||||
|
assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body")
|
||||||
|
|
||||||
|
// Assert that the metric was recorded
|
||||||
|
assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider")
|
||||||
|
assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. Test case: A handler that does NOT panic
|
||||||
|
t.Run("does not interfere with a non-panicking handler", func(t *testing.T) {
|
||||||
|
// Reset mock state for this sub-test
|
||||||
|
mockProvider.panicRecorded = false
|
||||||
|
|
||||||
|
successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
|
||||||
|
testHandler := PanicRecovery(successHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
testHandler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200")
|
||||||
|
assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body")
|
||||||
|
assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic")
|
||||||
|
})
|
||||||
|
}
|
||||||
52
pkg/server/interfaces.go
Normal file
52
pkg/server/interfaces.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds the configuration for a single web server instance.
|
||||||
|
type Config struct {
|
||||||
|
Name string
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Description string
|
||||||
|
SSLCert string
|
||||||
|
SSLKey string
|
||||||
|
GZIP bool
|
||||||
|
// Handler is the http.Handler (e.g., a router) to be served.
|
||||||
|
Handler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instance defines the interface for a single server instance.
|
||||||
|
// It abstracts the underlying http.Server, allowing for easier management and testing.
|
||||||
|
type Instance interface {
|
||||||
|
// Start begins serving requests. This method should be non-blocking and
|
||||||
|
// run the server in a separate goroutine.
|
||||||
|
Start() error
|
||||||
|
// Stop gracefully shuts down the server without interrupting any active connections.
|
||||||
|
// It accepts a context to allow for a timeout.
|
||||||
|
Stop(ctx context.Context) error
|
||||||
|
// Addr returns the network address the server is listening on.
|
||||||
|
Addr() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager defines the interface for a server manager.
|
||||||
|
// It is responsible for managing the lifecycle of multiple server instances.
|
||||||
|
type Manager interface {
|
||||||
|
// Add registers a new server instance based on the provided configuration.
|
||||||
|
// The server is not started until StartAll or Start is called on the instance.
|
||||||
|
Add(cfg Config) (Instance, error)
|
||||||
|
// Get returns a server instance by its name.
|
||||||
|
Get(name string) (Instance, error)
|
||||||
|
// Remove stops and removes a server instance by its name.
|
||||||
|
Remove(name string) error
|
||||||
|
// StartAll starts all registered server instances that are not already running.
|
||||||
|
StartAll() error
|
||||||
|
// StopAll gracefully shuts down all running server instances.
|
||||||
|
StopAll() error
|
||||||
|
// RestartAll gracefully restarts all running server instances.
|
||||||
|
RestartAll() error
|
||||||
|
// List returns all registered server instances.
|
||||||
|
List() []Instance
|
||||||
|
}
|
||||||
282
pkg/server/manager.go
Normal file
282
pkg/server/manager.go
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
"github.com/klauspost/compress/gzhttp"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// serverManager manages a collection of server instances.
|
||||||
|
type serverManager struct {
|
||||||
|
instances map[string]Instance
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new server manager.
|
||||||
|
func NewManager() Manager {
|
||||||
|
return &serverManager{
|
||||||
|
instances: make(map[string]Instance),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add registers a new server instance.
|
||||||
|
func (sm *serverManager) Add(cfg Config) (Instance, error) {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
if cfg.Name == "" {
|
||||||
|
return nil, fmt.Errorf("server name cannot be empty")
|
||||||
|
}
|
||||||
|
if _, exists := sm.instances[cfg.Name]; exists {
|
||||||
|
return nil, fmt.Errorf("server with name '%s' already exists", cfg.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
instance, err := newInstance(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sm.instances[cfg.Name] = instance
|
||||||
|
return instance, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a server instance by its name.
|
||||||
|
func (sm *serverManager) Get(name string) (Instance, error) {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
|
||||||
|
instance, exists := sm.instances[name]
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("server with name '%s' not found", name)
|
||||||
|
}
|
||||||
|
return instance, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove stops and removes a server instance by its name.
|
||||||
|
func (sm *serverManager) Remove(name string) error {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
instance, exists := sm.instances[name]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("server with name '%s' not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop the server if it's running
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := instance.Stop(ctx); err != nil {
|
||||||
|
logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err, context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(sm.instances, name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartAll starts all registered server instances.
|
||||||
|
func (sm *serverManager) StartAll() error {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
|
||||||
|
var startErrors []error
|
||||||
|
for name, instance := range sm.instances {
|
||||||
|
if err := instance.Start(); err != nil {
|
||||||
|
startErrors = append(startErrors, fmt.Errorf("failed to start server '%s': %w", name, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(startErrors) > 0 {
|
||||||
|
// In a real-world scenario, you might want a more sophisticated error handling strategy
|
||||||
|
return fmt.Errorf("encountered errors while starting servers: %v", startErrors)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopAll gracefully shuts down all running server instances.
|
||||||
|
func (sm *serverManager) StopAll() error {
|
||||||
|
sm.mu.RLock()
|
||||||
|
instancesToStop := make([]Instance, 0, len(sm.instances))
|
||||||
|
for _, instance := range sm.instances {
|
||||||
|
instancesToStop = append(instancesToStop, instance)
|
||||||
|
}
|
||||||
|
sm.mu.RUnlock()
|
||||||
|
|
||||||
|
logger.Info("Shutting down all servers...", context.Background())
|
||||||
|
|
||||||
|
var shutdownErrors []error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for _, instance := range instancesToStop {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(inst Instance) {
|
||||||
|
defer wg.Done()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := inst.Stop(ctx); err != nil {
|
||||||
|
shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Addr(), err))
|
||||||
|
}
|
||||||
|
}(instance)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if len(shutdownErrors) > 0 {
|
||||||
|
return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors)
|
||||||
|
}
|
||||||
|
logger.Info("All servers stopped gracefully.", context.Background())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestartAll gracefully restarts all running server instances.
|
||||||
|
func (sm *serverManager) RestartAll() error {
|
||||||
|
logger.Info("Restarting all servers...", context.Background())
|
||||||
|
if err := sm.StopAll(); err != nil {
|
||||||
|
return fmt.Errorf("failed to stop servers during restart: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give ports time to be released
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
if err := sm.StartAll(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start servers during restart: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("All servers restarted successfully.", context.Background())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all registered server instances.
|
||||||
|
func (sm *serverManager) List() []Instance {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
|
||||||
|
instances := make([]Instance, 0, len(sm.instances))
|
||||||
|
for _, instance := range sm.instances {
|
||||||
|
instances = append(instances, instance)
|
||||||
|
}
|
||||||
|
return instances
|
||||||
|
}
|
||||||
|
|
||||||
|
// serverInstance is a concrete implementation of the Instance interface.
|
||||||
|
type serverInstance struct {
|
||||||
|
cfg Config
|
||||||
|
httpServer *http.Server
|
||||||
|
mu sync.RWMutex
|
||||||
|
running bool
|
||||||
|
stopCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newInstance creates a new, unstarted server instance from a config.
|
||||||
|
func newInstance(cfg Config) (*serverInstance, error) {
|
||||||
|
if cfg.Handler == nil {
|
||||||
|
return nil, fmt.Errorf("handler cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||||
|
var handler http.Handler = cfg.Handler
|
||||||
|
|
||||||
|
// Wrap with GZIP handler if enabled
|
||||||
|
if cfg.GZIP {
|
||||||
|
gz, err := gzhttp.NewWrapper(gzhttp.BestSpeed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err)
|
||||||
|
}
|
||||||
|
handler = gz(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap with the panic recovery middleware
|
||||||
|
handler = middleware.PanicRecovery(handler)
|
||||||
|
|
||||||
|
// Here you could add other default middleware like request logging, metrics, etc.
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: 15 * time.Second,
|
||||||
|
WriteTimeout: 15 * time.Second,
|
||||||
|
IdleTimeout: 60 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &serverInstance{
|
||||||
|
cfg: cfg,
|
||||||
|
httpServer: httpServer,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins serving requests in a new goroutine.
|
||||||
|
func (s *serverInstance) Start() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.running {
|
||||||
|
return fmt.Errorf("server '%s' is already running", s.cfg.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasSSL := s.cfg.SSLCert != "" && s.cfg.SSLKey != ""
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.running = false
|
||||||
|
s.mu.Unlock()
|
||||||
|
logger.Info("Server '%s' stopped.", s.cfg.Name, context.Background())
|
||||||
|
}()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
protocol := "HTTP"
|
||||||
|
|
||||||
|
if hasSSL {
|
||||||
|
protocol = "HTTPS"
|
||||||
|
// Configure TLS + HTTP/2
|
||||||
|
s.httpServer.TLSConfig = &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr(), context.Background()) err = s.httpServer.ListenAndServeTLS(s.cfg.SSLCert, s.cfg.SSLKey)
|
||||||
|
} else {
|
||||||
|
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr(), context.Background())
|
||||||
|
err = s.httpServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the server stopped for a reason other than a graceful shutdown, log the error.
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
logger.Error("Server '%s' failed: %v", s.cfg.Name, err, context.Background())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.running = true
|
||||||
|
// A small delay to allow the goroutine to start and potentially fail on binding.
|
||||||
|
// A more robust solution might involve a channel signal.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the server.
|
||||||
|
func (s *serverInstance) Stop(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if !s.running {
|
||||||
|
return nil // Already stopped
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name)
|
||||||
|
err := s.httpServer.Shutdown(ctx)
|
||||||
|
if err == nil {
|
||||||
|
s.running = false
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr returns the network address the server is listening on.
|
||||||
|
func (s *serverInstance) Addr() string {
|
||||||
|
return s.httpServer.Addr
|
||||||
|
}
|
||||||
125
pkg/server/manager_test.go
Normal file
125
pkg/server/manager_test.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getFreePort asks the kernel for a free open port that is ready to use.
|
||||||
|
func getFreePort(t *testing.T) int {
|
||||||
|
t.Helper()
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
l, err := net.ListenTCP("tcp", addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer l.Close()
|
||||||
|
return l.Addr().(*net.TCPAddr).Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerManagerLifecycle(t *testing.T) {
|
||||||
|
// Initialize logger for test output
|
||||||
|
logger.Init(true)
|
||||||
|
|
||||||
|
// Create a new server manager
|
||||||
|
sm := NewManager()
|
||||||
|
|
||||||
|
// Define a simple test handler
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("Hello, World!"))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get a free port for the server to listen on to avoid conflicts
|
||||||
|
testPort := getFreePort(t)
|
||||||
|
|
||||||
|
// Add a new server configuration
|
||||||
|
serverConfig := Config{
|
||||||
|
Name: "TestServer",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: testPort,
|
||||||
|
Handler: testHandler,
|
||||||
|
}
|
||||||
|
instance, err := sm.Add(serverConfig)
|
||||||
|
require.NoError(t, err, "should be able to add a new server")
|
||||||
|
require.NotNil(t, instance, "added instance should not be nil")
|
||||||
|
|
||||||
|
// --- Test StartAll ---
|
||||||
|
err = sm.StartAll()
|
||||||
|
require.NoError(t, err, "StartAll should not return an error")
|
||||||
|
|
||||||
|
// Give the server a moment to start up
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// --- Verify Server is Running ---
|
||||||
|
client := &http.Client{Timeout: 2 * time.Second}
|
||||||
|
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||||
|
resp, err := client.Get(url)
|
||||||
|
require.NoError(t, err, "should be able to make a request to the running server")
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode, "expected status OK from the test server")
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp.Body.Close()
|
||||||
|
assert.Equal(t, "Hello, World!", string(body), "response body should match expected value")
|
||||||
|
|
||||||
|
// --- Test Get ---
|
||||||
|
retrievedInstance, err := sm.Get("TestServer")
|
||||||
|
require.NoError(t, err, "should be able to get server by name")
|
||||||
|
assert.Equal(t, instance.Addr(), retrievedInstance.Addr(), "retrieved instance should be the same")
|
||||||
|
|
||||||
|
// --- Test List ---
|
||||||
|
instanceList := sm.List()
|
||||||
|
require.Len(t, instanceList, 1, "list should contain one instance")
|
||||||
|
assert.Equal(t, instance.Addr(), instanceList[0].Addr(), "listed instance should be the same")
|
||||||
|
|
||||||
|
// --- Test StopAll ---
|
||||||
|
err = sm.StopAll()
|
||||||
|
require.NoError(t, err, "StopAll should not return an error")
|
||||||
|
|
||||||
|
// Give the server a moment to shut down
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// --- Verify Server is Stopped ---
|
||||||
|
_, err = client.Get(url)
|
||||||
|
require.Error(t, err, "should not be able to make a request to a stopped server")
|
||||||
|
|
||||||
|
// --- Test Remove ---
|
||||||
|
err = sm.Remove("TestServer")
|
||||||
|
require.NoError(t, err, "should be able to remove a server")
|
||||||
|
|
||||||
|
_, err = sm.Get("TestServer")
|
||||||
|
require.Error(t, err, "should not be able to get a removed server")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerErrorCases(t *testing.T) {
|
||||||
|
logger.Init(true)
|
||||||
|
sm := NewManager()
|
||||||
|
testPort := getFreePort(t)
|
||||||
|
|
||||||
|
// --- Test Add Duplicate Name ---
|
||||||
|
config1 := Config{Name: "Duplicate", Host: "localhost", Port: testPort, Handler: http.NewServeMux()}
|
||||||
|
_, err := sm.Add(config1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config2 := Config{Name: "Duplicate", Host: "localhost", Port: getFreePort(t), Handler: http.NewServeMux()}
|
||||||
|
_, err = sm.Add(config2)
|
||||||
|
require.Error(t, err, "should not be able to add a server with a duplicate name")
|
||||||
|
|
||||||
|
// --- Test Get Non-existent ---
|
||||||
|
_, err = sm.Get("NonExistent")
|
||||||
|
require.Error(t, err, "should get an error for a non-existent server")
|
||||||
|
|
||||||
|
// --- Test Add with Nil Handler ---
|
||||||
|
config3 := Config{Name: "NilHandler", Host: "localhost", Port: getFreePort(t), Handler: nil}
|
||||||
|
_, err = sm.Add(config3)
|
||||||
|
require.Error(t, err, "should not be able to add a server with a nil handler")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user