mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 08:14:25 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4a6f9c4c2 | ||
| 8f83e8fdc1 | |||
|
|
ed67caf055 | ||
| 4d1b8b6982 |
@@ -208,21 +208,9 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||
// If tableName is provided and the condition DOESN'T have a table prefix,
|
||||
// qualify unambiguous column references to prevent "ambiguous column" errors
|
||||
// when there are multiple joins on the same table (e.g., recursive preloads)
|
||||
columnName := extractUnqualifiedColumnName(condToCheck)
|
||||
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
|
||||
// Qualify the column with the table name
|
||||
// Be careful to only replace the column name, not other occurrences of the string
|
||||
oldRef := columnName
|
||||
newRef := tableName + "." + columnName
|
||||
// Use word boundary matching to avoid replacing partial matches
|
||||
cond = qualifyColumnInCondition(cond, oldRef, newRef)
|
||||
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
|
||||
}
|
||||
}
|
||||
// Note: We no longer add prefixes to unqualified columns here.
|
||||
// Use AddTablePrefixToColumns() separately if you need to add prefixes.
|
||||
|
||||
validConditions = append(validConditions, cond)
|
||||
}
|
||||
@@ -633,3 +621,145 @@ func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
}
|
||||
return validColumns[strings.ToLower(columnName)]
|
||||
}
|
||||
|
||||
// AddTablePrefixToColumns adds table prefix to unqualified column references in a WHERE clause.
|
||||
// This function only prefixes simple column references and skips:
|
||||
// - Columns already having a table prefix (containing a dot)
|
||||
// - Columns inside function calls or expressions (inside parentheses)
|
||||
// - Columns inside subqueries
|
||||
// - Columns that don't exist in the table (validation via model registry)
|
||||
//
|
||||
// Examples:
|
||||
// - "status = 'active'" -> "users.status = 'active'" (if status exists in users table)
|
||||
// - "COALESCE(status, 'default') = 'active'" -> unchanged (status inside function)
|
||||
// - "users.status = 'active'" -> unchanged (already has prefix)
|
||||
// - "(status = 'active')" -> "(users.status = 'active')" (grouping parens are OK)
|
||||
// - "invalid_col = 'value'" -> unchanged (if invalid_col doesn't exist in table)
|
||||
//
|
||||
// Parameters:
|
||||
// - where: The WHERE clause to process
|
||||
// - tableName: The table name to use as prefix
|
||||
//
|
||||
// Returns:
|
||||
// - The WHERE clause with table prefixes added to appropriate and valid columns
|
||||
func AddTablePrefixToColumns(where string, tableName string) string {
|
||||
if where == "" || tableName == "" {
|
||||
return where
|
||||
}
|
||||
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Get valid columns from the model registry for validation
|
||||
validColumns := getValidColumnsForTable(tableName)
|
||||
|
||||
// Split by AND to handle multiple conditions (parenthesis-aware)
|
||||
conditions := splitByAND(where)
|
||||
prefixedConditions := make([]string, 0, len(conditions))
|
||||
|
||||
for _, cond := range conditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process this condition to add table prefix if appropriate
|
||||
processedCond := addPrefixToSingleCondition(cond, tableName, validColumns)
|
||||
prefixedConditions = append(prefixedConditions, processedCond)
|
||||
}
|
||||
|
||||
if len(prefixedConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.Join(prefixedConditions, " AND ")
|
||||
}
|
||||
|
||||
// addPrefixToSingleCondition adds table prefix to a single condition if appropriate
|
||||
// Returns the condition unchanged if:
|
||||
// - The condition is a SQL literal/expression (true, false, null, 1=1, etc.)
|
||||
// - The column reference is inside a function call
|
||||
// - The column already has a table prefix
|
||||
// - No valid column reference is found
|
||||
// - The column doesn't exist in the table (when validColumns is provided)
|
||||
func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string {
|
||||
// Strip outer grouping parentheses to get to the actual condition
|
||||
strippedCond := stripOuterParentheses(cond)
|
||||
|
||||
// Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.)
|
||||
if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) {
|
||||
logger.Debug("Skipping SQL literal/trivial condition: '%s'", strippedCond)
|
||||
return cond
|
||||
}
|
||||
|
||||
// Extract the left side of the comparison (before the operator)
|
||||
columnRef := extractLeftSideOfComparison(strippedCond)
|
||||
if columnRef == "" {
|
||||
return cond
|
||||
}
|
||||
|
||||
// Skip if it already has a prefix (contains a dot)
|
||||
if strings.Contains(columnRef, ".") {
|
||||
logger.Debug("Skipping column '%s' - already has table prefix", columnRef)
|
||||
return cond
|
||||
}
|
||||
|
||||
// Skip if it's a function call or expression (contains parentheses)
|
||||
if strings.Contains(columnRef, "(") {
|
||||
logger.Debug("Skipping column reference '%s' - inside function or expression", columnRef)
|
||||
return cond
|
||||
}
|
||||
|
||||
// Validate that the column exists in the table (if we have column info)
|
||||
if !isValidColumn(columnRef, validColumns) {
|
||||
logger.Debug("Skipping column '%s' - not found in table '%s'", columnRef, tableName)
|
||||
return cond
|
||||
}
|
||||
|
||||
// It's a simple unqualified column reference that exists in the table - add the table prefix
|
||||
newRef := tableName + "." + columnRef
|
||||
result := qualifyColumnInCondition(cond, columnRef, newRef)
|
||||
logger.Debug("Added table prefix to column: '%s' -> '%s'", columnRef, newRef)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractLeftSideOfComparison extracts the left side of a comparison operator from a condition.
|
||||
// This is used to identify the column reference that may need a table prefix.
|
||||
//
|
||||
// Examples:
|
||||
// - "status = 'active'" returns "status"
|
||||
// - "COALESCE(status, 'default') = 'active'" returns "COALESCE(status, 'default')"
|
||||
// - "priority > 5" returns "priority"
|
||||
//
|
||||
// Returns empty string if no operator is found.
|
||||
func extractLeftSideOfComparison(cond string) string {
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||
|
||||
// Find the first operator outside of parentheses and quotes
|
||||
minIdx := -1
|
||||
for _, op := range operators {
|
||||
idx := findOperatorOutsideParentheses(cond, op)
|
||||
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||
minIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if minIdx > 0 {
|
||||
leftSide := strings.TrimSpace(cond[:minIdx])
|
||||
// Remove any surrounding quotes
|
||||
leftSide = strings.Trim(leftSide, "`\"'")
|
||||
return leftSide
|
||||
}
|
||||
|
||||
// No operator found - might be a boolean column
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnRef := strings.Trim(parts[0], "`\"'")
|
||||
// Make sure it's not a SQL keyword
|
||||
if !IsSQLKeyword(strings.ToLower(columnRef)) {
|
||||
return columnRef
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -75,6 +75,25 @@ func CloseErrorTracking() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractContext attempts to find a context.Context in the given arguments.
|
||||
// It returns the found context (or context.Background() if not found) and
|
||||
// the remaining arguments without the context.
|
||||
func extractContext(args ...interface{}) (context.Context, []interface{}) {
|
||||
ctx := context.Background()
|
||||
var newArgs []interface{}
|
||||
found := false
|
||||
|
||||
for _, arg := range args {
|
||||
if c, ok := arg.(context.Context); ok && !found {
|
||||
ctx = c
|
||||
found = true
|
||||
} else {
|
||||
newArgs = append(newArgs, arg)
|
||||
}
|
||||
}
|
||||
return ctx, newArgs
|
||||
}
|
||||
|
||||
func Info(template string, args ...interface{}) {
|
||||
if Logger == nil {
|
||||
log.Printf(template, args...)
|
||||
@@ -84,7 +103,8 @@ func Info(template string, args ...interface{}) {
|
||||
}
|
||||
|
||||
func Warn(template string, args ...interface{}) {
|
||||
message := fmt.Sprintf(template, args...)
|
||||
ctx, remainingArgs := extractContext(args...)
|
||||
message := fmt.Sprintf(template, remainingArgs...)
|
||||
if Logger == nil {
|
||||
log.Printf("%s", message)
|
||||
} else {
|
||||
@@ -93,14 +113,15 @@ func Warn(template string, args ...interface{}) {
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
|
||||
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Error(template string, args ...interface{}) {
|
||||
message := fmt.Sprintf(template, args...)
|
||||
ctx, remainingArgs := extractContext(args...)
|
||||
message := fmt.Sprintf(template, remainingArgs...)
|
||||
if Logger == nil {
|
||||
log.Printf("%s", message)
|
||||
} else {
|
||||
@@ -109,7 +130,7 @@ func Error(template string, args ...interface{}) {
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
|
||||
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
@@ -124,12 +145,13 @@ func Debug(template string, args ...interface{}) {
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanicCallback(location string, cb func(err any)) {
|
||||
func CatchPanicCallback(location string, cb func(err any), args ...interface{}) {
|
||||
ctx, _ := extractContext(args...)
|
||||
if err := recover(); err != nil {
|
||||
callstack := debug.Stack()
|
||||
|
||||
if Logger != nil {
|
||||
Error("Panic in %s : %v", location, err)
|
||||
Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly
|
||||
} else {
|
||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||
debug.PrintStack()
|
||||
@@ -137,7 +159,7 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
|
||||
errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{
|
||||
"location": location,
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
@@ -150,8 +172,8 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanic(location string) {
|
||||
CatchPanicCallback(location, nil)
|
||||
func CatchPanic(location string, args ...interface{}) {
|
||||
CatchPanicCallback(location, nil, args...)
|
||||
}
|
||||
|
||||
// HandlePanic logs a panic and returns it as an error
|
||||
@@ -163,13 +185,14 @@ func CatchPanic(location string) {
|
||||
// err = logger.HandlePanic("MethodName", r)
|
||||
// }
|
||||
// }()
|
||||
func HandlePanic(methodName string, r any) error {
|
||||
func HandlePanic(methodName string, r any, args ...interface{}) error {
|
||||
ctx, _ := extractContext(args...)
|
||||
stack := debug.Stack()
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
|
||||
errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{
|
||||
"method": methodName,
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
|
||||
@@ -39,6 +39,9 @@ type Provider interface {
|
||||
// UpdateEventQueueSize updates the event queue size metric
|
||||
UpdateEventQueueSize(size int64)
|
||||
|
||||
// RecordPanic records a panic event
|
||||
RecordPanic(methodName string)
|
||||
|
||||
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
||||
Handler() http.Handler
|
||||
}
|
||||
@@ -75,6 +78,7 @@ func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
|
||||
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||
}
|
||||
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
|
||||
func (n *NoOpProvider) RecordPanic(methodName string) {}
|
||||
func (n *NoOpProvider) Handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
|
||||
@@ -20,6 +20,7 @@ type PrometheusProvider struct {
|
||||
cacheHits *prometheus.CounterVec
|
||||
cacheMisses *prometheus.CounterVec
|
||||
cacheSize *prometheus.GaugeVec
|
||||
panicsTotal *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// NewPrometheusProvider creates a new Prometheus metrics provider
|
||||
@@ -83,6 +84,13 @@ func NewPrometheusProvider() *PrometheusProvider {
|
||||
},
|
||||
[]string{"provider"},
|
||||
),
|
||||
panicsTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "panics_total",
|
||||
Help: "Total number of panics",
|
||||
},
|
||||
[]string{"method"},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,6 +153,11 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
|
||||
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
||||
}
|
||||
|
||||
// RecordPanic implements the Provider interface
|
||||
func (p *PrometheusProvider) RecordPanic(methodName string) {
|
||||
p.panicsTotal.WithLabelValues(methodName).Inc()
|
||||
}
|
||||
|
||||
// Handler implements Provider interface
|
||||
func (p *PrometheusProvider) Handler() http.Handler {
|
||||
return promhttp.Handler()
|
||||
|
||||
33
pkg/middleware/panic.go
Normal file
33
pkg/middleware/panic.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
const panicMiddlewareMethodName = "PanicMiddleware"
|
||||
|
||||
// PanicRecovery is a middleware that recovers from panics, logs the error,
|
||||
// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error.
|
||||
func PanicRecovery(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rcv := recover(); rcv != nil {
|
||||
// Record the panic metric
|
||||
metrics.GetProvider().RecordPanic(panicMiddlewareMethodName)
|
||||
|
||||
// Log the panic and send to error tracker
|
||||
// We pass the request context so the error tracker can potentially
|
||||
// link the panic to the request trace.
|
||||
ctx := r.Context()
|
||||
err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx)
|
||||
|
||||
// Respond with a 500 error
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
86
pkg/middleware/panic_test.go
Normal file
86
pkg/middleware/panic_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockMetricsProvider is a mock for the metrics provider to check if methods are called.
|
||||
type mockMetricsProvider struct {
|
||||
metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods
|
||||
panicRecorded bool
|
||||
methodName string
|
||||
}
|
||||
|
||||
func (m *mockMetricsProvider) RecordPanic(methodName string) {
|
||||
m.panicRecorded = true
|
||||
m.methodName = methodName
|
||||
}
|
||||
|
||||
func TestPanicRecovery(t *testing.T) {
|
||||
// Initialize a mock logger to avoid actual logging output during tests
|
||||
logger.Init(true)
|
||||
|
||||
// Setup mock metrics provider
|
||||
mockProvider := &mockMetricsProvider{}
|
||||
originalProvider := metrics.GetProvider()
|
||||
metrics.SetProvider(mockProvider)
|
||||
defer metrics.SetProvider(originalProvider) // Restore original provider after test
|
||||
|
||||
// 1. Test case: A handler that panics
|
||||
t.Run("recovers from panic and returns 500", func(t *testing.T) {
|
||||
// Reset mock state for this sub-test
|
||||
mockProvider.panicRecorded = false
|
||||
mockProvider.methodName = ""
|
||||
|
||||
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("something went terribly wrong")
|
||||
})
|
||||
|
||||
// Create the middleware wrapping the panicking handler
|
||||
testHandler := PanicRecovery(panicHandler)
|
||||
|
||||
// Create a test request and response recorder
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
testHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Assertions
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500")
|
||||
assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body")
|
||||
|
||||
// Assert that the metric was recorded
|
||||
assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider")
|
||||
assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name")
|
||||
})
|
||||
|
||||
// 2. Test case: A handler that does NOT panic
|
||||
t.Run("does not interfere with a non-panicking handler", func(t *testing.T) {
|
||||
// Reset mock state for this sub-test
|
||||
mockProvider.panicRecorded = false
|
||||
|
||||
successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
testHandler := PanicRecovery(successHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
testHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Assertions
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200")
|
||||
assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body")
|
||||
assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic")
|
||||
})
|
||||
}
|
||||
@@ -482,8 +482,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply custom SQL WHERE clause (AND condition)
|
||||
if options.CustomSQLWhere != "" {
|
||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
// First add table prefixes to unqualified columns (but skip columns inside function calls)
|
||||
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
if sanitizedWhere != "" {
|
||||
query = query.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -492,8 +494,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply custom SQL WHERE clause (OR condition)
|
||||
if options.CustomSQLOr != "" {
|
||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
if sanitizedOr != "" {
|
||||
query = query.WhereOr(sanitizedOr)
|
||||
}
|
||||
|
||||
@@ -1,233 +1,314 @@
|
||||
# Server Package
|
||||
|
||||
Graceful HTTP server with request draining and shutdown coordination.
|
||||
Production-ready HTTP server manager with graceful shutdown, request draining, and comprehensive TLS/HTTPS support.
|
||||
|
||||
## Features
|
||||
|
||||
✅ **Multiple Server Management** - Run multiple HTTP/HTTPS servers concurrently
|
||||
✅ **Graceful Shutdown** - Handles SIGINT/SIGTERM with request draining
|
||||
✅ **Automatic Request Rejection** - New requests get 503 during shutdown
|
||||
✅ **Health & Readiness Endpoints** - Kubernetes-ready health checks
|
||||
✅ **Shutdown Callbacks** - Register cleanup functions (DB, cache, metrics)
|
||||
✅ **Comprehensive TLS Support**:
|
||||
- Certificate files (production)
|
||||
- Self-signed certificates (development/testing)
|
||||
- Let's Encrypt / AutoTLS (automatic certificate management)
|
||||
✅ **GZIP Compression** - Optional response compression
|
||||
✅ **Panic Recovery** - Automatic panic recovery middleware
|
||||
✅ **Configurable Timeouts** - Read, write, idle, drain, and shutdown timeouts
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Single Server
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
|
||||
// Create server
|
||||
srv := server.NewGracefulServer(server.Config{
|
||||
Addr: ":8080",
|
||||
Handler: router,
|
||||
// Create server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Add server
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: myRouter,
|
||||
GZIP: true,
|
||||
})
|
||||
|
||||
// Start server (blocks until shutdown signal)
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
// Start and wait for shutdown signal
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
### Multiple Servers
|
||||
|
||||
✅ Graceful shutdown on SIGINT/SIGTERM
|
||||
✅ Request draining (waits for in-flight requests)
|
||||
✅ Automatic request rejection during shutdown
|
||||
✅ Health and readiness endpoints
|
||||
✅ Shutdown callbacks for cleanup
|
||||
✅ Configurable timeouts
|
||||
```go
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Public API
|
||||
mgr.Add(server.Config{
|
||||
Name: "public-api",
|
||||
Port: 8080,
|
||||
Handler: publicRouter,
|
||||
})
|
||||
|
||||
// Admin API
|
||||
mgr.Add(server.Config{
|
||||
Name: "admin-api",
|
||||
Port: 8081,
|
||||
Handler: adminRouter,
|
||||
})
|
||||
|
||||
// Start all and wait
|
||||
mgr.ServeWithGracefulShutdown()
|
||||
```
|
||||
|
||||
## HTTPS/TLS Configuration
|
||||
|
||||
### Option 1: Certificate Files (Production)
|
||||
|
||||
```go
|
||||
mgr.Add(server.Config{
|
||||
Name: "https-server",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
SSLCert: "/etc/ssl/certs/server.crt",
|
||||
SSLKey: "/etc/ssl/private/server.key",
|
||||
})
|
||||
```
|
||||
|
||||
### Option 2: Self-Signed Certificate (Development)
|
||||
|
||||
```go
|
||||
mgr.Add(server.Config{
|
||||
Name: "dev-server",
|
||||
Host: "localhost",
|
||||
Port: 8443,
|
||||
Handler: handler,
|
||||
SelfSignedSSL: true, // Auto-generates certificate
|
||||
})
|
||||
```
|
||||
|
||||
### Option 3: Let's Encrypt / AutoTLS (Production)
|
||||
|
||||
```go
|
||||
mgr.Add(server.Config{
|
||||
Name: "prod-server",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
AutoTLS: true,
|
||||
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||
AutoTLSEmail: "admin@example.com",
|
||||
AutoTLSCacheDir: "./certs-cache", // Certificate cache directory
|
||||
})
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
```go
|
||||
config := server.Config{
|
||||
// Server address
|
||||
Addr: ":8080",
|
||||
server.Config{
|
||||
// Basic configuration
|
||||
Name: "my-server", // Server name (required)
|
||||
Host: "0.0.0.0", // Bind address
|
||||
Port: 8080, // Port (required)
|
||||
Handler: myRouter, // HTTP handler (required)
|
||||
Description: "My API server", // Optional description
|
||||
|
||||
// HTTP handler
|
||||
Handler: myRouter,
|
||||
// Features
|
||||
GZIP: true, // Enable GZIP compression
|
||||
|
||||
// Maximum time for graceful shutdown (default: 30s)
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
// TLS/HTTPS (choose one option)
|
||||
SSLCert: "/path/to/cert.pem", // Certificate file
|
||||
SSLKey: "/path/to/key.pem", // Key file
|
||||
SelfSignedSSL: false, // Auto-generate self-signed cert
|
||||
AutoTLS: false, // Let's Encrypt
|
||||
AutoTLSDomains: []string{}, // Domains for AutoTLS
|
||||
AutoTLSEmail: "", // Email for Let's Encrypt
|
||||
AutoTLSCacheDir: "./certs-cache", // Cert cache directory
|
||||
|
||||
// Time to wait for in-flight requests (default: 25s)
|
||||
DrainTimeout: 25 * time.Second,
|
||||
|
||||
// Request read timeout (default: 10s)
|
||||
ReadTimeout: 10 * time.Second,
|
||||
|
||||
// Response write timeout (default: 10s)
|
||||
WriteTimeout: 10 * time.Second,
|
||||
|
||||
// Idle connection timeout (default: 120s)
|
||||
IdleTimeout: 120 * time.Second,
|
||||
// Timeouts
|
||||
ShutdownTimeout: 30 * time.Second, // Max shutdown time
|
||||
DrainTimeout: 25 * time.Second, // Request drain timeout
|
||||
ReadTimeout: 15 * time.Second, // Request read timeout
|
||||
WriteTimeout: 15 * time.Second, // Response write timeout
|
||||
IdleTimeout: 60 * time.Second, // Idle connection timeout
|
||||
}
|
||||
|
||||
srv := server.NewGracefulServer(config)
|
||||
```
|
||||
|
||||
## Shutdown Behavior
|
||||
## Graceful Shutdown
|
||||
|
||||
**Signal received (SIGINT/SIGTERM):**
|
||||
### Automatic (Recommended)
|
||||
|
||||
1. **Mark as shutting down** - New requests get 503
|
||||
2. **Drain requests** - Wait up to `DrainTimeout` for in-flight requests
|
||||
3. **Shutdown server** - Close listeners and connections
|
||||
4. **Execute callbacks** - Run registered cleanup functions
|
||||
```go
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Add servers...
|
||||
|
||||
// This blocks until SIGINT/SIGTERM
|
||||
mgr.ServeWithGracefulShutdown()
|
||||
```
|
||||
Time Event
|
||||
─────────────────────────────────────────
|
||||
0s Signal received: SIGTERM
|
||||
├─ Mark as shutting down
|
||||
├─ Reject new requests (503)
|
||||
└─ Start draining...
|
||||
|
||||
1s In-flight: 50 requests
|
||||
2s In-flight: 32 requests
|
||||
3s In-flight: 12 requests
|
||||
4s In-flight: 3 requests
|
||||
5s In-flight: 0 requests ✓
|
||||
└─ All requests drained
|
||||
### Manual Control
|
||||
|
||||
5s Execute shutdown callbacks
|
||||
6s Shutdown complete
|
||||
```go
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Add and start servers
|
||||
mgr.StartAll()
|
||||
|
||||
// Later... stop gracefully
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := mgr.StopAllWithContext(ctx); err != nil {
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### Shutdown Callbacks
|
||||
|
||||
Register cleanup functions to run during shutdown:
|
||||
|
||||
```go
|
||||
// Close database
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Closing database...")
|
||||
return db.Close()
|
||||
})
|
||||
|
||||
// Flush metrics
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Flushing metrics...")
|
||||
return metrics.Flush(ctx)
|
||||
})
|
||||
|
||||
// Close cache
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Closing cache...")
|
||||
return cache.Close()
|
||||
})
|
||||
```
|
||||
|
||||
## Health Checks
|
||||
|
||||
### Health Endpoint
|
||||
|
||||
Returns 200 when healthy, 503 when shutting down:
|
||||
### Adding Health Endpoints
|
||||
|
||||
```go
|
||||
router.HandleFunc("/health", srv.HealthCheckHandler())
|
||||
instance, _ := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Port: 8080,
|
||||
Handler: router,
|
||||
})
|
||||
|
||||
// Add health endpoints to your router
|
||||
router.HandleFunc("/health", instance.HealthCheckHandler())
|
||||
router.HandleFunc("/ready", instance.ReadinessHandler())
|
||||
```
|
||||
|
||||
**Response (healthy):**
|
||||
### Health Endpoint
|
||||
|
||||
Returns server health status:
|
||||
|
||||
**Healthy (200 OK):**
|
||||
```json
|
||||
{"status":"healthy"}
|
||||
```
|
||||
|
||||
**Response (shutting down):**
|
||||
**Shutting Down (503 Service Unavailable):**
|
||||
```json
|
||||
{"status":"shutting_down"}
|
||||
```
|
||||
|
||||
### Readiness Endpoint
|
||||
|
||||
Includes in-flight request count:
|
||||
Returns readiness with in-flight request count:
|
||||
|
||||
```go
|
||||
router.HandleFunc("/ready", srv.ReadinessHandler())
|
||||
```
|
||||
|
||||
**Response:**
|
||||
**Ready (200 OK):**
|
||||
```json
|
||||
{"ready":true,"in_flight_requests":12}
|
||||
```
|
||||
|
||||
**During shutdown:**
|
||||
**Not Ready (503 Service Unavailable):**
|
||||
```json
|
||||
{"ready":false,"reason":"shutting_down"}
|
||||
```
|
||||
|
||||
## Shutdown Callbacks
|
||||
## Shutdown Behavior
|
||||
|
||||
Register cleanup functions to run during shutdown:
|
||||
When a shutdown signal (SIGINT/SIGTERM) is received:
|
||||
|
||||
```go
|
||||
// Close database
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
logger.Info("Closing database connection...")
|
||||
return db.Close()
|
||||
})
|
||||
1. **Mark as shutting down** → New requests get 503
|
||||
2. **Execute callbacks** → Run cleanup functions
|
||||
3. **Drain requests** → Wait up to `DrainTimeout` for in-flight requests
|
||||
4. **Shutdown servers** → Close listeners and connections
|
||||
|
||||
// Flush metrics
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
logger.Info("Flushing metrics...")
|
||||
return metricsProvider.Flush(ctx)
|
||||
})
|
||||
```
|
||||
Time Event
|
||||
─────────────────────────────────────────
|
||||
0s Signal received: SIGTERM
|
||||
├─ Mark servers as shutting down
|
||||
├─ Reject new requests (503)
|
||||
└─ Execute shutdown callbacks
|
||||
|
||||
// Close cache
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
logger.Info("Closing cache...")
|
||||
return cache.Close()
|
||||
})
|
||||
1s Callbacks complete
|
||||
└─ Start draining requests...
|
||||
|
||||
2s In-flight: 50 requests
|
||||
3s In-flight: 32 requests
|
||||
4s In-flight: 12 requests
|
||||
5s In-flight: 3 requests
|
||||
6s In-flight: 0 requests ✓
|
||||
└─ All requests drained
|
||||
|
||||
6s Shutdown servers
|
||||
7s All servers stopped ✓
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
## Server Management
|
||||
|
||||
### Get Server Instance
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialize metrics
|
||||
metricsProvider := metrics.NewPrometheusProvider()
|
||||
metrics.SetProvider(metricsProvider)
|
||||
|
||||
// Create router
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Apply middleware
|
||||
rateLimiter := middleware.NewRateLimiter(100, 20)
|
||||
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size10MB)
|
||||
sanitizer := middleware.DefaultSanitizer()
|
||||
|
||||
router.Use(rateLimiter.Middleware)
|
||||
router.Use(sizeLimiter.Middleware)
|
||||
router.Use(sanitizer.Middleware)
|
||||
router.Use(metricsProvider.Middleware)
|
||||
|
||||
// API routes
|
||||
router.HandleFunc("/api/data", dataHandler)
|
||||
|
||||
// Create graceful server
|
||||
srv := server.NewGracefulServer(server.Config{
|
||||
Addr: ":8080",
|
||||
Handler: router,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
DrainTimeout: 25 * time.Second,
|
||||
})
|
||||
|
||||
// Health checks
|
||||
router.HandleFunc("/health", srv.HealthCheckHandler())
|
||||
router.HandleFunc("/ready", srv.ReadinessHandler())
|
||||
|
||||
// Metrics endpoint
|
||||
router.Handle("/metrics", metricsProvider.Handler())
|
||||
|
||||
// Register shutdown callbacks
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Cleanup: Flushing metrics...")
|
||||
return nil
|
||||
})
|
||||
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Cleanup: Closing database...")
|
||||
// return db.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start server (blocks until shutdown)
|
||||
log.Printf("Starting server on :8080")
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
instance, err := mgr.Get("api-server")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for shutdown to complete
|
||||
srv.Wait()
|
||||
log.Println("Server stopped")
|
||||
}
|
||||
// Check status
|
||||
fmt.Printf("Address: %s\n", instance.Addr())
|
||||
fmt.Printf("Name: %s\n", instance.Name())
|
||||
fmt.Printf("In-flight: %d\n", instance.InFlightRequests())
|
||||
fmt.Printf("Shutting down: %v\n", instance.IsShuttingDown())
|
||||
```
|
||||
|
||||
func dataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Your handler logic
|
||||
time.Sleep(100 * time.Millisecond) // Simulate work
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"message":"success"}`))
|
||||
### List All Servers
|
||||
|
||||
```go
|
||||
instances := mgr.List()
|
||||
for _, instance := range instances {
|
||||
fmt.Printf("Server: %s at %s\n", instance.Name(), instance.Addr())
|
||||
}
|
||||
```
|
||||
|
||||
### Remove Server
|
||||
|
||||
```go
|
||||
// Stop and remove a server
|
||||
if err := mgr.Remove("api-server"); err != nil {
|
||||
log.Printf("Error removing server: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### Restart All Servers
|
||||
|
||||
```go
|
||||
// Gracefully restart all servers
|
||||
if err := mgr.RestartAll(); err != nil {
|
||||
log.Printf("Error restarting: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
@@ -250,23 +331,21 @@ spec:
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
|
||||
# Liveness probe - is app running?
|
||||
# Liveness probe
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: 8080
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
timeoutSeconds: 5
|
||||
|
||||
# Readiness probe - can app handle traffic?
|
||||
# Readiness probe
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /ready
|
||||
port: 8080
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 5
|
||||
timeoutSeconds: 3
|
||||
|
||||
# Graceful shutdown
|
||||
lifecycle:
|
||||
@@ -274,26 +353,12 @@ spec:
|
||||
exec:
|
||||
command: ["/bin/sh", "-c", "sleep 5"]
|
||||
|
||||
# Environment
|
||||
env:
|
||||
- name: SHUTDOWN_TIMEOUT
|
||||
value: "30"
|
||||
```
|
||||
|
||||
### Service
|
||||
|
||||
```yaml
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: myapp
|
||||
spec:
|
||||
selector:
|
||||
app: myapp
|
||||
ports:
|
||||
- port: 80
|
||||
targetPort: 8080
|
||||
type: LoadBalancer
|
||||
# Allow time for graceful shutdown
|
||||
terminationGracePeriodSeconds: 35
|
||||
```
|
||||
|
||||
## Docker Compose
|
||||
@@ -312,8 +377,70 @@ services:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
stop_grace_period: 35s # Slightly longer than shutdown timeout
|
||||
stop_grace_period: 35s
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Register shutdown callbacks
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Cleanup: Closing database...")
|
||||
// return db.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
// Create router
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("/api/data", dataHandler)
|
||||
|
||||
// Add server
|
||||
instance, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
Handler: router,
|
||||
GZIP: true,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
DrainTimeout: 25 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Add health endpoints
|
||||
router.HandleFunc("/health", instance.HealthCheckHandler())
|
||||
router.HandleFunc("/ready", instance.ReadinessHandler())
|
||||
|
||||
// Start and wait for shutdown
|
||||
log.Println("Starting server on :8080")
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
log.Printf("Server stopped: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Server shutdown complete")
|
||||
}
|
||||
|
||||
func dataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond) // Simulate work
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"message":"success"}`))
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Graceful Shutdown
|
||||
@@ -330,7 +457,7 @@ SERVER_PID=$!
|
||||
# Wait for server to start
|
||||
sleep 2
|
||||
|
||||
# Send some requests
|
||||
# Send requests
|
||||
for i in {1..10}; do
|
||||
curl http://localhost:8080/api/data &
|
||||
done
|
||||
@@ -341,7 +468,7 @@ sleep 1
|
||||
# Send shutdown signal
|
||||
kill -TERM $SERVER_PID
|
||||
|
||||
# Try to send more requests (should get 503)
|
||||
# Try more requests (should get 503)
|
||||
curl -v http://localhost:8080/api/data
|
||||
|
||||
# Wait for server to stop
|
||||
@@ -349,101 +476,13 @@ wait $SERVER_PID
|
||||
echo "Server stopped gracefully"
|
||||
```
|
||||
|
||||
### Expected Output
|
||||
|
||||
```
|
||||
Starting server on :8080
|
||||
Received signal: terminated, initiating graceful shutdown
|
||||
Starting graceful shutdown...
|
||||
Waiting for 8 in-flight requests to complete...
|
||||
Waiting for 4 in-flight requests to complete...
|
||||
Waiting for 1 in-flight requests to complete...
|
||||
All requests drained in 2.3s
|
||||
Cleanup: Flushing metrics...
|
||||
Cleanup: Closing database...
|
||||
Shutting down HTTP server...
|
||||
Graceful shutdown complete
|
||||
Server stopped
|
||||
```
|
||||
|
||||
## Monitoring In-Flight Requests
|
||||
|
||||
```go
|
||||
// Get current in-flight count
|
||||
count := srv.InFlightRequests()
|
||||
fmt.Printf("In-flight requests: %d\n", count)
|
||||
|
||||
// Check if shutting down
|
||||
if srv.IsShuttingDown() {
|
||||
fmt.Println("Server is shutting down")
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Shutdown Logic
|
||||
|
||||
```go
|
||||
// Implement custom shutdown
|
||||
go func() {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
<-sigChan
|
||||
log.Println("Shutdown signal received")
|
||||
|
||||
// Custom pre-shutdown logic
|
||||
log.Println("Running custom cleanup...")
|
||||
|
||||
// Shutdown with callbacks
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.ShutdownWithCallbacks(ctx); err != nil {
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start server
|
||||
srv.server.ListenAndServe()
|
||||
```
|
||||
|
||||
### Multiple Servers
|
||||
|
||||
```go
|
||||
// HTTP server
|
||||
httpSrv := server.NewGracefulServer(server.Config{
|
||||
Addr: ":8080",
|
||||
Handler: httpRouter,
|
||||
})
|
||||
|
||||
// HTTPS server
|
||||
httpsSrv := server.NewGracefulServer(server.Config{
|
||||
Addr: ":8443",
|
||||
Handler: httpsRouter,
|
||||
})
|
||||
|
||||
// Start both
|
||||
go httpSrv.ListenAndServe()
|
||||
go httpsSrv.ListenAndServe()
|
||||
|
||||
// Shutdown both on signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
ctx := context.Background()
|
||||
httpSrv.Shutdown(ctx)
|
||||
httpsSrv.Shutdown(ctx)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Set appropriate timeouts**
|
||||
- `DrainTimeout` < `ShutdownTimeout`
|
||||
- `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds`
|
||||
|
||||
2. **Register cleanup callbacks** for:
|
||||
2. **Use shutdown callbacks** for:
|
||||
- Database connections
|
||||
- Message queues
|
||||
- Metrics flushing
|
||||
@@ -458,7 +497,12 @@ httpsSrv.Shutdown(ctx)
|
||||
- Set `preStop` hook in Kubernetes (5-10s delay)
|
||||
- Allows load balancer to deregister before shutdown
|
||||
|
||||
5. **Monitoring**
|
||||
5. **HTTPS in production**
|
||||
- Use AutoTLS for public-facing services
|
||||
- Use certificate files for enterprise PKI
|
||||
- Use self-signed only for development/testing
|
||||
|
||||
6. **Monitoring**
|
||||
- Track in-flight requests in metrics
|
||||
- Alert on slow drains
|
||||
- Monitor shutdown duration
|
||||
@@ -470,24 +514,63 @@ httpsSrv.Shutdown(ctx)
|
||||
```go
|
||||
// Increase drain timeout
|
||||
config.DrainTimeout = 60 * time.Second
|
||||
config.ShutdownTimeout = 65 * time.Second
|
||||
```
|
||||
|
||||
### Requests Still Timing Out
|
||||
### Requests Timing Out
|
||||
|
||||
```go
|
||||
// Increase write timeout
|
||||
config.WriteTimeout = 30 * time.Second
|
||||
```
|
||||
|
||||
### Force Shutdown Not Working
|
||||
|
||||
The server will force shutdown after `ShutdownTimeout` even if requests are still in-flight. Adjust timeouts as needed.
|
||||
|
||||
### Debugging Shutdown
|
||||
### Certificate Issues
|
||||
|
||||
```go
|
||||
// Verify certificate files exist and are readable
|
||||
if _, err := os.Stat(config.SSLCert); err != nil {
|
||||
log.Fatalf("Certificate not found: %v", err)
|
||||
}
|
||||
|
||||
// For AutoTLS, ensure:
|
||||
// - Port 443 is accessible
|
||||
// - Domains resolve to server IP
|
||||
// - Cache directory is writable
|
||||
```
|
||||
|
||||
### Debug Logging
|
||||
|
||||
```go
|
||||
// Enable debug logging
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
|
||||
// Enable debug logging
|
||||
logger.SetLevel("debug")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Manager Methods
|
||||
|
||||
- `NewManager()` - Create new server manager
|
||||
- `Add(cfg Config)` - Register server instance
|
||||
- `Get(name string)` - Get server by name
|
||||
- `Remove(name string)` - Stop and remove server
|
||||
- `StartAll()` - Start all registered servers
|
||||
- `StopAll()` - Stop all servers gracefully
|
||||
- `StopAllWithContext(ctx)` - Stop with timeout
|
||||
- `RestartAll()` - Restart all servers
|
||||
- `List()` - Get all server instances
|
||||
- `ServeWithGracefulShutdown()` - Start and block until shutdown
|
||||
- `RegisterShutdownCallback(cb)` - Register cleanup function
|
||||
|
||||
### Instance Methods
|
||||
|
||||
- `Start()` - Start the server
|
||||
- `Stop(ctx)` - Stop gracefully
|
||||
- `Addr()` - Get server address
|
||||
- `Name()` - Get server name
|
||||
- `HealthCheckHandler()` - Get health handler
|
||||
- `ReadinessHandler()` - Get readiness handler
|
||||
- `InFlightRequests()` - Get in-flight count
|
||||
- `IsShuttingDown()` - Check shutdown status
|
||||
- `Wait()` - Block until shutdown complete
|
||||
|
||||
294
pkg/server/example_test.go
Normal file
294
pkg/server/example_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
)
|
||||
|
||||
// ExampleManager_basic demonstrates basic server manager usage
|
||||
func ExampleManager_basic() {
|
||||
// Create a server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Define a simple handler
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintln(w, "Hello from server!")
|
||||
})
|
||||
|
||||
// Add an HTTP server
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: handler,
|
||||
GZIP: true, // Enable GZIP compression
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Server is now running...
|
||||
// When done, stop gracefully
|
||||
if err := mgr.StopAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleManager_https demonstrates HTTPS configurations
|
||||
func ExampleManager_https() {
|
||||
mgr := server.NewManager()
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Secure connection!")
|
||||
})
|
||||
|
||||
// Option 1: Use certificate files
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "https-server-files",
|
||||
Host: "localhost",
|
||||
Port: 8443,
|
||||
Handler: handler,
|
||||
SSLCert: "/path/to/cert.pem",
|
||||
SSLKey: "/path/to/key.pem",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Option 2: Self-signed certificate (for development)
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "https-server-self-signed",
|
||||
Host: "localhost",
|
||||
Port: 8444,
|
||||
Handler: handler,
|
||||
SelfSignedSSL: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Option 3: Let's Encrypt / AutoTLS (for production)
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "https-server-letsencrypt",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
AutoTLS: true,
|
||||
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||
AutoTLSEmail: "admin@example.com",
|
||||
AutoTLSCacheDir: "./certs-cache",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
mgr.StopAll()
|
||||
}
|
||||
|
||||
// ExampleManager_gracefulShutdown demonstrates graceful shutdown with callbacks
|
||||
func ExampleManager_gracefulShutdown() {
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Register shutdown callbacks for cleanup tasks
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
fmt.Println("Closing database connections...")
|
||||
// Close your database here
|
||||
return nil
|
||||
})
|
||||
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
fmt.Println("Flushing metrics...")
|
||||
// Flush metrics here
|
||||
return nil
|
||||
})
|
||||
|
||||
// Add server with custom timeouts
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate some work
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
fmt.Fprintln(w, "Done!")
|
||||
})
|
||||
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: handler,
|
||||
ShutdownTimeout: 30 * time.Second, // Max time for shutdown
|
||||
DrainTimeout: 25 * time.Second, // Time to wait for in-flight requests
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start servers and block until shutdown signal (SIGINT/SIGTERM)
|
||||
// This will automatically handle graceful shutdown with callbacks
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
fmt.Printf("Shutdown completed: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleManager_healthChecks demonstrates health and readiness endpoints
|
||||
func ExampleManager_healthChecks() {
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Create a router with health endpoints
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/data", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Data endpoint")
|
||||
})
|
||||
|
||||
// Add server
|
||||
instance, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: mux,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Add health and readiness endpoints
|
||||
mux.HandleFunc("/health", instance.HealthCheckHandler())
|
||||
mux.HandleFunc("/ready", instance.ReadinessHandler())
|
||||
|
||||
// Start the server
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Health check returns:
|
||||
// - 200 OK with {"status":"healthy"} when healthy
|
||||
// - 503 Service Unavailable with {"status":"shutting_down"} when shutting down
|
||||
|
||||
// Readiness check returns:
|
||||
// - 200 OK with {"ready":true,"in_flight_requests":N} when ready
|
||||
// - 503 Service Unavailable with {"ready":false,"reason":"shutting_down"} when shutting down
|
||||
|
||||
// Cleanup
|
||||
mgr.StopAll()
|
||||
}
|
||||
|
||||
// ExampleManager_multipleServers demonstrates running multiple servers
|
||||
func ExampleManager_multipleServers() {
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Public API server
|
||||
publicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Public API")
|
||||
})
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "public-api",
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
Handler: publicHandler,
|
||||
GZIP: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Admin API server (different port)
|
||||
adminHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Admin API")
|
||||
})
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "admin-api",
|
||||
Host: "localhost",
|
||||
Port: 8081,
|
||||
Handler: adminHandler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Metrics server (internal only)
|
||||
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Metrics data")
|
||||
})
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "metrics",
|
||||
Host: "127.0.0.1",
|
||||
Port: 9090,
|
||||
Handler: metricsHandler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start all servers at once
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Get specific server instance
|
||||
publicInstance, err := mgr.Get("public-api")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
|
||||
|
||||
// List all servers
|
||||
instances := mgr.List()
|
||||
fmt.Printf("Running %d servers\n", len(instances))
|
||||
|
||||
// Stop all servers gracefully (in parallel)
|
||||
if err := mgr.StopAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleManager_monitoring demonstrates monitoring server state
|
||||
func ExampleManager_monitoring() {
|
||||
mgr := server.NewManager()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(50 * time.Millisecond) // Simulate work
|
||||
fmt.Fprintln(w, "Done")
|
||||
})
|
||||
|
||||
instance, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: handler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Check server status
|
||||
fmt.Printf("Server address: %s\n", instance.Addr())
|
||||
fmt.Printf("Server name: %s\n", instance.Name())
|
||||
fmt.Printf("Is shutting down: %v\n", instance.IsShuttingDown())
|
||||
fmt.Printf("In-flight requests: %d\n", instance.InFlightRequests())
|
||||
|
||||
// Cleanup
|
||||
mgr.StopAll()
|
||||
|
||||
// Wait for complete shutdown
|
||||
instance.Wait()
|
||||
}
|
||||
137
pkg/server/interfaces.go
Normal file
137
pkg/server/interfaces.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds the configuration for a single web server instance.
|
||||
type Config struct {
|
||||
Name string
|
||||
Host string
|
||||
Port int
|
||||
Description string
|
||||
|
||||
// Handler is the http.Handler (e.g., a router) to be served.
|
||||
Handler http.Handler
|
||||
|
||||
// GZIP compression support
|
||||
GZIP bool
|
||||
|
||||
// TLS/HTTPS configuration options (mutually exclusive)
|
||||
// Option 1: Provide certificate and key files directly
|
||||
SSLCert string
|
||||
SSLKey string
|
||||
|
||||
// Option 2: Use self-signed certificate (for development/testing)
|
||||
// Generates a self-signed certificate automatically if no SSLCert/SSLKey provided
|
||||
SelfSignedSSL bool
|
||||
|
||||
// Option 3: Use Let's Encrypt / Certbot for automatic TLS
|
||||
// AutoTLS enables automatic certificate management via Let's Encrypt
|
||||
AutoTLS bool
|
||||
// AutoTLSDomains specifies the domains for Let's Encrypt certificates
|
||||
AutoTLSDomains []string
|
||||
// AutoTLSCacheDir specifies where to cache certificates (default: "./certs-cache")
|
||||
AutoTLSCacheDir string
|
||||
// AutoTLSEmail is the email for Let's Encrypt registration (optional but recommended)
|
||||
AutoTLSEmail string
|
||||
|
||||
// Graceful shutdown configuration
|
||||
// ShutdownTimeout is the maximum time to wait for graceful shutdown
|
||||
// Default: 30 seconds
|
||||
ShutdownTimeout time.Duration
|
||||
|
||||
// DrainTimeout is the time to wait for in-flight requests to complete
|
||||
// before forcing shutdown. Default: 25 seconds
|
||||
DrainTimeout time.Duration
|
||||
|
||||
// ReadTimeout is the maximum duration for reading the entire request
|
||||
// Default: 15 seconds
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// WriteTimeout is the maximum duration before timing out writes of the response
|
||||
// Default: 15 seconds
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// IdleTimeout is the maximum amount of time to wait for the next request
|
||||
// Default: 60 seconds
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// Instance defines the interface for a single server instance.
|
||||
// It abstracts the underlying http.Server, allowing for easier management and testing.
|
||||
type Instance interface {
|
||||
// Start begins serving requests. This method should be non-blocking and
|
||||
// run the server in a separate goroutine.
|
||||
Start() error
|
||||
|
||||
// Stop gracefully shuts down the server without interrupting any active connections.
|
||||
// It accepts a context to allow for a timeout.
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// Addr returns the network address the server is listening on.
|
||||
Addr() string
|
||||
|
||||
// Name returns the server instance name.
|
||||
Name() string
|
||||
|
||||
// HealthCheckHandler returns a handler that responds to health checks.
|
||||
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down.
|
||||
HealthCheckHandler() http.HandlerFunc
|
||||
|
||||
// ReadinessHandler returns a handler for readiness checks.
|
||||
// Includes in-flight request count.
|
||||
ReadinessHandler() http.HandlerFunc
|
||||
|
||||
// InFlightRequests returns the current number of in-flight requests.
|
||||
InFlightRequests() int64
|
||||
|
||||
// IsShuttingDown returns true if the server is shutting down.
|
||||
IsShuttingDown() bool
|
||||
|
||||
// Wait blocks until shutdown is complete.
|
||||
Wait()
|
||||
}
|
||||
|
||||
// Manager defines the interface for a server manager.
|
||||
// It is responsible for managing the lifecycle of multiple server instances.
|
||||
type Manager interface {
|
||||
// Add registers a new server instance based on the provided configuration.
|
||||
// The server is not started until StartAll or Start is called on the instance.
|
||||
Add(cfg Config) (Instance, error)
|
||||
|
||||
// Get returns a server instance by its name.
|
||||
Get(name string) (Instance, error)
|
||||
|
||||
// Remove stops and removes a server instance by its name.
|
||||
Remove(name string) error
|
||||
|
||||
// StartAll starts all registered server instances that are not already running.
|
||||
StartAll() error
|
||||
|
||||
// StopAll gracefully shuts down all running server instances.
|
||||
// Executes shutdown callbacks and drains in-flight requests.
|
||||
StopAll() error
|
||||
|
||||
// StopAllWithContext gracefully shuts down all running server instances with a context.
|
||||
StopAllWithContext(ctx context.Context) error
|
||||
|
||||
// RestartAll gracefully restarts all running server instances.
|
||||
RestartAll() error
|
||||
|
||||
// List returns all registered server instances.
|
||||
List() []Instance
|
||||
|
||||
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
|
||||
// It handles SIGINT and SIGTERM signals and performs graceful shutdown with callbacks.
|
||||
ServeWithGracefulShutdown() error
|
||||
|
||||
// RegisterShutdownCallback registers a callback to be called during shutdown.
|
||||
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
|
||||
RegisterShutdownCallback(cb ShutdownCallback)
|
||||
}
|
||||
|
||||
// ShutdownCallback is a function called during graceful shutdown.
|
||||
type ShutdownCallback func(context.Context) error
|
||||
572
pkg/server/manager.go
Normal file
572
pkg/server/manager.go
Normal file
@@ -0,0 +1,572 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
"github.com/klauspost/compress/gzhttp"
|
||||
)
|
||||
|
||||
// gracefulServer wraps http.Server with graceful shutdown capabilities (internal type)
|
||||
type gracefulServer struct {
|
||||
server *http.Server
|
||||
shutdownTimeout time.Duration
|
||||
drainTimeout time.Duration
|
||||
inFlightRequests atomic.Int64
|
||||
isShuttingDown atomic.Bool
|
||||
shutdownOnce sync.Once
|
||||
shutdownComplete chan struct{}
|
||||
}
|
||||
|
||||
// trackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
|
||||
func (gs *gracefulServer) trackRequestsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if shutting down
|
||||
if gs.isShuttingDown.Load() {
|
||||
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Increment in-flight counter
|
||||
gs.inFlightRequests.Add(1)
|
||||
defer gs.inFlightRequests.Add(-1)
|
||||
|
||||
// Serve the request
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// shutdown performs graceful shutdown with request draining
|
||||
func (gs *gracefulServer) shutdown(ctx context.Context) error {
|
||||
var shutdownErr error
|
||||
|
||||
gs.shutdownOnce.Do(func() {
|
||||
logger.Info("Starting graceful shutdown...")
|
||||
|
||||
// Mark as shutting down (new requests will be rejected)
|
||||
gs.isShuttingDown.Store(true)
|
||||
|
||||
// Create context with timeout
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Wait for in-flight requests to complete (with drain timeout)
|
||||
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
|
||||
defer drainCancel()
|
||||
|
||||
shutdownErr = gs.drainRequests(drainCtx)
|
||||
if shutdownErr != nil {
|
||||
logger.Error("Error draining requests: %v", shutdownErr)
|
||||
}
|
||||
|
||||
// Shutdown the server
|
||||
logger.Info("Shutting down HTTP server...")
|
||||
if err := gs.server.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error("Error shutting down server: %v", err)
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Graceful shutdown complete")
|
||||
close(gs.shutdownComplete)
|
||||
})
|
||||
|
||||
return shutdownErr
|
||||
}
|
||||
|
||||
// drainRequests waits for in-flight requests to complete
|
||||
func (gs *gracefulServer) drainRequests(ctx context.Context) error {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
inFlight := gs.inFlightRequests.Load()
|
||||
|
||||
if inFlight == 0 {
|
||||
logger.Info("All requests drained in %v", time.Since(startTime))
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
|
||||
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
|
||||
case <-ticker.C:
|
||||
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inFlightRequests returns the current number of in-flight requests
|
||||
func (gs *gracefulServer) inFlightRequestsCount() int64 {
|
||||
return gs.inFlightRequests.Load()
|
||||
}
|
||||
|
||||
// isShutdown returns true if the server is shutting down
|
||||
func (gs *gracefulServer) isShutdown() bool {
|
||||
return gs.isShuttingDown.Load()
|
||||
}
|
||||
|
||||
// wait blocks until shutdown is complete
|
||||
func (gs *gracefulServer) wait() {
|
||||
<-gs.shutdownComplete
|
||||
}
|
||||
|
||||
// healthCheckHandler returns a handler that responds to health checks
|
||||
func (gs *gracefulServer) healthCheckHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if gs.isShutdown() {
|
||||
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte(`{"status":"healthy"}`))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write health check response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readinessHandler returns a handler for readiness checks
|
||||
func (gs *gracefulServer) readinessHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if gs.isShutdown() {
|
||||
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
inFlight := gs.inFlightRequestsCount()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
// serverManager manages a collection of server instances with graceful shutdown support.
|
||||
type serverManager struct {
|
||||
instances map[string]Instance
|
||||
mu sync.RWMutex
|
||||
shutdownCallbacks []ShutdownCallback
|
||||
callbacksMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new server manager.
|
||||
func NewManager() Manager {
|
||||
return &serverManager{
|
||||
instances: make(map[string]Instance),
|
||||
shutdownCallbacks: make([]ShutdownCallback, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Add registers a new server instance.
|
||||
func (sm *serverManager) Add(cfg Config) (Instance, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if cfg.Name == "" {
|
||||
return nil, fmt.Errorf("server name cannot be empty")
|
||||
}
|
||||
if _, exists := sm.instances[cfg.Name]; exists {
|
||||
return nil, fmt.Errorf("server with name '%s' already exists", cfg.Name)
|
||||
}
|
||||
|
||||
instance, err := newInstance(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sm.instances[cfg.Name] = instance
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Get returns a server instance by its name.
|
||||
func (sm *serverManager) Get(name string) (Instance, error) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
instance, exists := sm.instances[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("server with name '%s' not found", name)
|
||||
}
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Remove stops and removes a server instance by its name.
|
||||
func (sm *serverManager) Remove(name string) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
instance, exists := sm.instances[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("server with name '%s' not found", name)
|
||||
}
|
||||
|
||||
// Stop the server if it's running
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := instance.Stop(ctx); err != nil {
|
||||
logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err)
|
||||
}
|
||||
|
||||
delete(sm.instances, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartAll starts all registered server instances.
|
||||
func (sm *serverManager) StartAll() error {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
var startErrors []error
|
||||
for name, instance := range sm.instances {
|
||||
if err := instance.Start(); err != nil {
|
||||
startErrors = append(startErrors, fmt.Errorf("failed to start server '%s': %w", name, err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(startErrors) > 0 {
|
||||
return fmt.Errorf("encountered errors while starting servers: %v", startErrors)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopAll gracefully shuts down all running server instances.
|
||||
func (sm *serverManager) StopAll() error {
|
||||
return sm.StopAllWithContext(context.Background())
|
||||
}
|
||||
|
||||
// StopAllWithContext gracefully shuts down all running server instances with a context.
|
||||
func (sm *serverManager) StopAllWithContext(ctx context.Context) error {
|
||||
sm.mu.RLock()
|
||||
instancesToStop := make([]Instance, 0, len(sm.instances))
|
||||
for _, instance := range sm.instances {
|
||||
instancesToStop = append(instancesToStop, instance)
|
||||
}
|
||||
sm.mu.RUnlock()
|
||||
|
||||
logger.Info("Shutting down all servers...")
|
||||
|
||||
// Execute shutdown callbacks first
|
||||
sm.callbacksMu.Lock()
|
||||
callbacks := make([]ShutdownCallback, len(sm.shutdownCallbacks))
|
||||
copy(callbacks, sm.shutdownCallbacks)
|
||||
sm.callbacksMu.Unlock()
|
||||
|
||||
if len(callbacks) > 0 {
|
||||
logger.Info("Executing %d shutdown callbacks...", len(callbacks))
|
||||
for i, cb := range callbacks {
|
||||
if err := cb(ctx); err != nil {
|
||||
logger.Error("Shutdown callback %d failed: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop all instances in parallel
|
||||
var shutdownErrors []error
|
||||
var wg sync.WaitGroup
|
||||
var errorsMu sync.Mutex
|
||||
|
||||
for _, instance := range instancesToStop {
|
||||
wg.Add(1)
|
||||
go func(inst Instance) {
|
||||
defer wg.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
if err := inst.Stop(shutdownCtx); err != nil {
|
||||
errorsMu.Lock()
|
||||
shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Name(), err))
|
||||
errorsMu.Unlock()
|
||||
}
|
||||
}(instance)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(shutdownErrors) > 0 {
|
||||
return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors)
|
||||
}
|
||||
logger.Info("All servers stopped gracefully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartAll gracefully restarts all running server instances.
|
||||
func (sm *serverManager) RestartAll() error {
|
||||
logger.Info("Restarting all servers...")
|
||||
if err := sm.StopAll(); err != nil {
|
||||
return fmt.Errorf("failed to stop servers during restart: %w", err)
|
||||
}
|
||||
|
||||
// Give ports time to be released
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
if err := sm.StartAll(); err != nil {
|
||||
return fmt.Errorf("failed to start servers during restart: %w", err)
|
||||
}
|
||||
logger.Info("All servers restarted successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns all registered server instances.
|
||||
func (sm *serverManager) List() []Instance {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
instances := make([]Instance, 0, len(sm.instances))
|
||||
for _, instance := range sm.instances {
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
return instances
|
||||
}
|
||||
|
||||
// RegisterShutdownCallback registers a callback to be called during shutdown.
|
||||
func (sm *serverManager) RegisterShutdownCallback(cb ShutdownCallback) {
|
||||
sm.callbacksMu.Lock()
|
||||
defer sm.callbacksMu.Unlock()
|
||||
sm.shutdownCallbacks = append(sm.shutdownCallbacks, cb)
|
||||
}
|
||||
|
||||
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
|
||||
func (sm *serverManager) ServeWithGracefulShutdown() error {
|
||||
// Start all servers
|
||||
if err := sm.StartAll(); err != nil {
|
||||
return fmt.Errorf("failed to start servers: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("All servers started. Waiting for shutdown signal...")
|
||||
|
||||
// Wait for interrupt signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
sig := <-sigChan
|
||||
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
|
||||
|
||||
// Create context with timeout for shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return sm.StopAllWithContext(ctx)
|
||||
}
|
||||
|
||||
// serverInstance is a concrete implementation of the Instance interface.
|
||||
// It wraps gracefulServer to provide graceful shutdown capabilities.
|
||||
type serverInstance struct {
|
||||
cfg Config
|
||||
gracefulServer *gracefulServer
|
||||
certFile string // Path to certificate file (may be temporary for self-signed)
|
||||
keyFile string // Path to key file (may be temporary for self-signed)
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
serverErr chan error
|
||||
}
|
||||
|
||||
// newInstance creates a new, unstarted server instance from a config.
|
||||
func newInstance(cfg Config) (*serverInstance, error) {
|
||||
if cfg.Handler == nil {
|
||||
return nil, fmt.Errorf("handler cannot be nil")
|
||||
}
|
||||
|
||||
// Set default timeouts
|
||||
if cfg.ShutdownTimeout == 0 {
|
||||
cfg.ShutdownTimeout = 30 * time.Second
|
||||
}
|
||||
if cfg.DrainTimeout == 0 {
|
||||
cfg.DrainTimeout = 25 * time.Second
|
||||
}
|
||||
if cfg.ReadTimeout == 0 {
|
||||
cfg.ReadTimeout = 15 * time.Second
|
||||
}
|
||||
if cfg.WriteTimeout == 0 {
|
||||
cfg.WriteTimeout = 15 * time.Second
|
||||
}
|
||||
if cfg.IdleTimeout == 0 {
|
||||
cfg.IdleTimeout = 60 * time.Second
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||
var handler http.Handler = cfg.Handler
|
||||
|
||||
// Wrap with GZIP handler if enabled
|
||||
if cfg.GZIP {
|
||||
gz, err := gzhttp.NewWrapper()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err)
|
||||
}
|
||||
handler = gz(handler)
|
||||
}
|
||||
|
||||
// Wrap with the panic recovery middleware
|
||||
handler = middleware.PanicRecovery(handler)
|
||||
|
||||
// Configure TLS if any TLS option is enabled
|
||||
tlsConfig, certFile, keyFile, err := configureTLS(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to configure TLS: %w", err)
|
||||
}
|
||||
|
||||
// Create gracefulServer
|
||||
gracefulSrv := &gracefulServer{
|
||||
server: &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
IdleTimeout: cfg.IdleTimeout,
|
||||
TLSConfig: tlsConfig,
|
||||
},
|
||||
shutdownTimeout: cfg.ShutdownTimeout,
|
||||
drainTimeout: cfg.DrainTimeout,
|
||||
shutdownComplete: make(chan struct{}),
|
||||
}
|
||||
|
||||
return &serverInstance{
|
||||
cfg: cfg,
|
||||
gracefulServer: gracefulSrv,
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverErr: make(chan error, 1),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins serving requests in a new goroutine.
|
||||
func (s *serverInstance) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("server '%s' is already running", s.cfg.Name)
|
||||
}
|
||||
|
||||
// Determine if we're using TLS
|
||||
useTLS := s.cfg.SSLCert != "" || s.cfg.SSLKey != "" || s.cfg.SelfSignedSSL || s.cfg.AutoTLS
|
||||
|
||||
// Wrap handler with request tracking
|
||||
s.gracefulServer.server.Handler = s.gracefulServer.trackRequestsMiddleware(s.gracefulServer.server.Handler)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.running = false
|
||||
s.mu.Unlock()
|
||||
logger.Info("Server '%s' stopped.", s.cfg.Name)
|
||||
}()
|
||||
|
||||
var err error
|
||||
protocol := "HTTP"
|
||||
|
||||
if useTLS {
|
||||
protocol = "HTTPS"
|
||||
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
|
||||
|
||||
// For AutoTLS, we need to use a TLS listener
|
||||
if s.cfg.AutoTLS {
|
||||
// Create listener
|
||||
ln, lnErr := net.Listen("tcp", s.gracefulServer.server.Addr)
|
||||
if lnErr != nil {
|
||||
err = fmt.Errorf("failed to create listener: %w", lnErr)
|
||||
} else {
|
||||
// Wrap with TLS
|
||||
tlsListener := tls.NewListener(ln, s.gracefulServer.server.TLSConfig)
|
||||
err = s.gracefulServer.server.Serve(tlsListener)
|
||||
}
|
||||
} else {
|
||||
// Use certificate files (regular SSL or self-signed)
|
||||
err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
|
||||
err = s.gracefulServer.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// If the server stopped for a reason other than a graceful shutdown, log and report the error.
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("Server '%s' failed: %v", s.cfg.Name, err)
|
||||
select {
|
||||
case s.serverErr <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.running = true
|
||||
// A small delay to allow the goroutine to start and potentially fail on binding.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Check if the server failed to start
|
||||
select {
|
||||
case err := <-s.serverErr:
|
||||
s.running = false
|
||||
return err
|
||||
default:
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the server.
|
||||
func (s *serverInstance) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.running {
|
||||
return nil // Already stopped
|
||||
}
|
||||
|
||||
logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name)
|
||||
err := s.gracefulServer.shutdown(ctx)
|
||||
if err == nil {
|
||||
s.running = false
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Addr returns the network address the server is listening on.
|
||||
func (s *serverInstance) Addr() string {
|
||||
return s.gracefulServer.server.Addr
|
||||
}
|
||||
|
||||
// Name returns the server instance name.
|
||||
func (s *serverInstance) Name() string {
|
||||
return s.cfg.Name
|
||||
}
|
||||
|
||||
// HealthCheckHandler returns a handler that responds to health checks.
|
||||
func (s *serverInstance) HealthCheckHandler() http.HandlerFunc {
|
||||
return s.gracefulServer.healthCheckHandler()
|
||||
}
|
||||
|
||||
// ReadinessHandler returns a handler for readiness checks.
|
||||
func (s *serverInstance) ReadinessHandler() http.HandlerFunc {
|
||||
return s.gracefulServer.readinessHandler()
|
||||
}
|
||||
|
||||
// InFlightRequests returns the current number of in-flight requests.
|
||||
func (s *serverInstance) InFlightRequests() int64 {
|
||||
return s.gracefulServer.inFlightRequestsCount()
|
||||
}
|
||||
|
||||
// IsShuttingDown returns true if the server is shutting down.
|
||||
func (s *serverInstance) IsShuttingDown() bool {
|
||||
return s.gracefulServer.isShutdown()
|
||||
}
|
||||
|
||||
// Wait blocks until shutdown is complete.
|
||||
func (s *serverInstance) Wait() {
|
||||
s.gracefulServer.wait()
|
||||
}
|
||||
328
pkg/server/manager_test.go
Normal file
328
pkg/server/manager_test.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// getFreePort asks the kernel for a free open port that is ready to use.
|
||||
func getFreePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
func TestServerManagerLifecycle(t *testing.T) {
|
||||
// Initialize logger for test output
|
||||
logger.Init(true)
|
||||
|
||||
// Create a new server manager
|
||||
sm := NewManager()
|
||||
|
||||
// Define a simple test handler
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
// Get a free port for the server to listen on to avoid conflicts
|
||||
testPort := getFreePort(t)
|
||||
|
||||
// Add a new server configuration
|
||||
serverConfig := Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: testHandler,
|
||||
}
|
||||
instance, err := sm.Add(serverConfig)
|
||||
require.NoError(t, err, "should be able to add a new server")
|
||||
require.NotNil(t, instance, "added instance should not be nil")
|
||||
|
||||
// --- Test StartAll ---
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err, "StartAll should not return an error")
|
||||
|
||||
// Give the server a moment to start up
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// --- Verify Server is Running ---
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||
resp, err := client.Get(url)
|
||||
require.NoError(t, err, "should be able to make a request to the running server")
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "expected status OK from the test server")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
assert.Equal(t, "Hello, World!", string(body), "response body should match expected value")
|
||||
|
||||
// --- Test Get ---
|
||||
retrievedInstance, err := sm.Get("TestServer")
|
||||
require.NoError(t, err, "should be able to get server by name")
|
||||
assert.Equal(t, instance.Addr(), retrievedInstance.Addr(), "retrieved instance should be the same")
|
||||
|
||||
// --- Test List ---
|
||||
instanceList := sm.List()
|
||||
require.Len(t, instanceList, 1, "list should contain one instance")
|
||||
assert.Equal(t, instance.Addr(), instanceList[0].Addr(), "listed instance should be the same")
|
||||
|
||||
// --- Test StopAll ---
|
||||
err = sm.StopAll()
|
||||
require.NoError(t, err, "StopAll should not return an error")
|
||||
|
||||
// Give the server a moment to shut down
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// --- Verify Server is Stopped ---
|
||||
_, err = client.Get(url)
|
||||
require.Error(t, err, "should not be able to make a request to a stopped server")
|
||||
|
||||
// --- Test Remove ---
|
||||
err = sm.Remove("TestServer")
|
||||
require.NoError(t, err, "should be able to remove a server")
|
||||
|
||||
_, err = sm.Get("TestServer")
|
||||
require.Error(t, err, "should not be able to get a removed server")
|
||||
}
|
||||
|
||||
func TestManagerErrorCases(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
testPort := getFreePort(t)
|
||||
|
||||
// --- Test Add Duplicate Name ---
|
||||
config1 := Config{Name: "Duplicate", Host: "localhost", Port: testPort, Handler: http.NewServeMux()}
|
||||
_, err := sm.Add(config1)
|
||||
require.NoError(t, err)
|
||||
|
||||
config2 := Config{Name: "Duplicate", Host: "localhost", Port: getFreePort(t), Handler: http.NewServeMux()}
|
||||
_, err = sm.Add(config2)
|
||||
require.Error(t, err, "should not be able to add a server with a duplicate name")
|
||||
|
||||
// --- Test Get Non-existent ---
|
||||
_, err = sm.Get("NonExistent")
|
||||
require.Error(t, err, "should get an error for a non-existent server")
|
||||
|
||||
// --- Test Add with Nil Handler ---
|
||||
config3 := Config{Name: "NilHandler", Host: "localhost", Port: getFreePort(t), Handler: nil}
|
||||
_, err = sm.Add(config3)
|
||||
require.Error(t, err, "should not be able to add a server with a nil handler")
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
|
||||
requestsHandled := 0
|
||||
var requestsMu sync.Mutex
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestsMu.Lock()
|
||||
requestsHandled++
|
||||
requestsMu.Unlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
testPort := getFreePort(t)
|
||||
instance, err := sm.Add(Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: handler,
|
||||
DrainTimeout: 2 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give server time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Send some concurrent requests
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||
resp, err := client.Get(url)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait a bit for requests to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Check in-flight requests
|
||||
inFlight := instance.InFlightRequests()
|
||||
assert.Greater(t, inFlight, int64(0), "Should have in-flight requests")
|
||||
|
||||
// Stop the server
|
||||
err = sm.StopAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for all requests to complete
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests were handled
|
||||
requestsMu.Lock()
|
||||
handled := requestsHandled
|
||||
requestsMu.Unlock()
|
||||
assert.GreaterOrEqual(t, handled, 1, "At least some requests should have been handled")
|
||||
|
||||
// Verify no in-flight requests
|
||||
assert.Equal(t, int64(0), instance.InFlightRequests(), "Should have no in-flight requests after shutdown")
|
||||
}
|
||||
|
||||
func TestHealthAndReadinessEndpoints(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
testPort := getFreePort(t)
|
||||
|
||||
instance, err := sm.Add(Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: mux,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add health and readiness endpoints
|
||||
mux.HandleFunc("/health", instance.HealthCheckHandler())
|
||||
mux.HandleFunc("/ready", instance.ReadinessHandler())
|
||||
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
baseURL := fmt.Sprintf("http://localhost:%d", testPort)
|
||||
|
||||
// Test health endpoint
|
||||
resp, err := client.Get(baseURL + "/health")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
assert.Contains(t, string(body), "healthy")
|
||||
|
||||
// Test readiness endpoint
|
||||
resp, err = client.Get(baseURL + "/ready")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, _ = io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
assert.Contains(t, string(body), "ready")
|
||||
assert.Contains(t, string(body), "in_flight_requests")
|
||||
|
||||
// Stop the server
|
||||
sm.StopAll()
|
||||
}
|
||||
|
||||
func TestRequestRejectionDuringShutdown(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
testPort := getFreePort(t)
|
||||
_, err := sm.Add(Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: handler,
|
||||
DrainTimeout: 1 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Start shutdown in background
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
sm.StopAll()
|
||||
}()
|
||||
|
||||
// Give shutdown time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to make a request after shutdown started
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||
resp, err := client.Get(url)
|
||||
|
||||
// The request should either fail (connection refused) or get 503
|
||||
if err == nil {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Should get 503 during shutdown")
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownCallbacks(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
|
||||
callbackExecuted := false
|
||||
var callbackMu sync.Mutex
|
||||
|
||||
sm.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
callbackMu.Lock()
|
||||
callbackExecuted = true
|
||||
callbackMu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
testPort := getFreePort(t)
|
||||
_, err := sm.Add(Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: http.NewServeMux(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = sm.StopAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
callbackMu.Lock()
|
||||
executed := callbackExecuted
|
||||
callbackMu.Unlock()
|
||||
|
||||
assert.True(t, executed, "Shutdown callback should have been executed")
|
||||
}
|
||||
@@ -1,296 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// GracefulServer wraps http.Server with graceful shutdown capabilities
|
||||
type GracefulServer struct {
|
||||
server *http.Server
|
||||
shutdownTimeout time.Duration
|
||||
drainTimeout time.Duration
|
||||
inFlightRequests atomic.Int64
|
||||
isShuttingDown atomic.Bool
|
||||
shutdownOnce sync.Once
|
||||
shutdownComplete chan struct{}
|
||||
}
|
||||
|
||||
// Config holds configuration for the graceful server
|
||||
type Config struct {
|
||||
// Addr is the server address (e.g., ":8080")
|
||||
Addr string
|
||||
|
||||
// Handler is the HTTP handler
|
||||
Handler http.Handler
|
||||
|
||||
// ShutdownTimeout is the maximum time to wait for graceful shutdown
|
||||
// Default: 30 seconds
|
||||
ShutdownTimeout time.Duration
|
||||
|
||||
// DrainTimeout is the time to wait for in-flight requests to complete
|
||||
// before forcing shutdown. Default: 25 seconds
|
||||
DrainTimeout time.Duration
|
||||
|
||||
// ReadTimeout is the maximum duration for reading the entire request
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// WriteTimeout is the maximum duration before timing out writes of the response
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// IdleTimeout is the maximum amount of time to wait for the next request
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewGracefulServer creates a new graceful server
|
||||
func NewGracefulServer(config Config) *GracefulServer {
|
||||
if config.ShutdownTimeout == 0 {
|
||||
config.ShutdownTimeout = 30 * time.Second
|
||||
}
|
||||
if config.DrainTimeout == 0 {
|
||||
config.DrainTimeout = 25 * time.Second
|
||||
}
|
||||
if config.ReadTimeout == 0 {
|
||||
config.ReadTimeout = 10 * time.Second
|
||||
}
|
||||
if config.WriteTimeout == 0 {
|
||||
config.WriteTimeout = 10 * time.Second
|
||||
}
|
||||
if config.IdleTimeout == 0 {
|
||||
config.IdleTimeout = 120 * time.Second
|
||||
}
|
||||
|
||||
gs := &GracefulServer{
|
||||
server: &http.Server{
|
||||
Addr: config.Addr,
|
||||
Handler: config.Handler,
|
||||
ReadTimeout: config.ReadTimeout,
|
||||
WriteTimeout: config.WriteTimeout,
|
||||
IdleTimeout: config.IdleTimeout,
|
||||
},
|
||||
shutdownTimeout: config.ShutdownTimeout,
|
||||
drainTimeout: config.DrainTimeout,
|
||||
shutdownComplete: make(chan struct{}),
|
||||
}
|
||||
|
||||
return gs
|
||||
}
|
||||
|
||||
// TrackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
|
||||
func (gs *GracefulServer) TrackRequestsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if shutting down
|
||||
if gs.isShuttingDown.Load() {
|
||||
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Increment in-flight counter
|
||||
gs.inFlightRequests.Add(1)
|
||||
defer gs.inFlightRequests.Add(-1)
|
||||
|
||||
// Serve the request
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// ListenAndServe starts the server and handles graceful shutdown
|
||||
func (gs *GracefulServer) ListenAndServe() error {
|
||||
// Wrap handler with request tracking
|
||||
gs.server.Handler = gs.TrackRequestsMiddleware(gs.server.Handler)
|
||||
|
||||
// Start server in goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
logger.Info("Starting server on %s", gs.server.Addr)
|
||||
if err := gs.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
serverErr <- err
|
||||
}
|
||||
close(serverErr)
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
select {
|
||||
case err := <-serverErr:
|
||||
return err
|
||||
case sig := <-sigChan:
|
||||
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
|
||||
return gs.Shutdown(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown performs graceful shutdown with request draining
|
||||
func (gs *GracefulServer) Shutdown(ctx context.Context) error {
|
||||
var shutdownErr error
|
||||
|
||||
gs.shutdownOnce.Do(func() {
|
||||
logger.Info("Starting graceful shutdown...")
|
||||
|
||||
// Mark as shutting down (new requests will be rejected)
|
||||
gs.isShuttingDown.Store(true)
|
||||
|
||||
// Create context with timeout
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Wait for in-flight requests to complete (with drain timeout)
|
||||
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
|
||||
defer drainCancel()
|
||||
|
||||
shutdownErr = gs.drainRequests(drainCtx)
|
||||
if shutdownErr != nil {
|
||||
logger.Error("Error draining requests: %v", shutdownErr)
|
||||
}
|
||||
|
||||
// Shutdown the server
|
||||
logger.Info("Shutting down HTTP server...")
|
||||
if err := gs.server.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error("Error shutting down server: %v", err)
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Graceful shutdown complete")
|
||||
close(gs.shutdownComplete)
|
||||
})
|
||||
|
||||
return shutdownErr
|
||||
}
|
||||
|
||||
// drainRequests waits for in-flight requests to complete
|
||||
func (gs *GracefulServer) drainRequests(ctx context.Context) error {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
inFlight := gs.inFlightRequests.Load()
|
||||
|
||||
if inFlight == 0 {
|
||||
logger.Info("All requests drained in %v", time.Since(startTime))
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
|
||||
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
|
||||
case <-ticker.C:
|
||||
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InFlightRequests returns the current number of in-flight requests
|
||||
func (gs *GracefulServer) InFlightRequests() int64 {
|
||||
return gs.inFlightRequests.Load()
|
||||
}
|
||||
|
||||
// IsShuttingDown returns true if the server is shutting down
|
||||
func (gs *GracefulServer) IsShuttingDown() bool {
|
||||
return gs.isShuttingDown.Load()
|
||||
}
|
||||
|
||||
// Wait blocks until shutdown is complete
|
||||
func (gs *GracefulServer) Wait() {
|
||||
<-gs.shutdownComplete
|
||||
}
|
||||
|
||||
// HealthCheckHandler returns a handler that responds to health checks
|
||||
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down
|
||||
func (gs *GracefulServer) HealthCheckHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if gs.IsShuttingDown() {
|
||||
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte(`{"status":"healthy"}`))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write. %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadinessHandler returns a handler for readiness checks
|
||||
// Includes in-flight request count
|
||||
func (gs *GracefulServer) ReadinessHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if gs.IsShuttingDown() {
|
||||
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
inFlight := gs.InFlightRequests()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
// ShutdownCallback is a function called during shutdown
|
||||
type ShutdownCallback func(context.Context) error
|
||||
|
||||
// shutdownCallbacks stores registered shutdown callbacks
|
||||
var (
|
||||
shutdownCallbacks []ShutdownCallback
|
||||
shutdownCallbacksMu sync.Mutex
|
||||
)
|
||||
|
||||
// RegisterShutdownCallback registers a callback to be called during shutdown
|
||||
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
|
||||
func RegisterShutdownCallback(cb ShutdownCallback) {
|
||||
shutdownCallbacksMu.Lock()
|
||||
defer shutdownCallbacksMu.Unlock()
|
||||
shutdownCallbacks = append(shutdownCallbacks, cb)
|
||||
}
|
||||
|
||||
// executeShutdownCallbacks runs all registered shutdown callbacks
|
||||
func executeShutdownCallbacks(ctx context.Context) error {
|
||||
shutdownCallbacksMu.Lock()
|
||||
callbacks := make([]ShutdownCallback, len(shutdownCallbacks))
|
||||
copy(callbacks, shutdownCallbacks)
|
||||
shutdownCallbacksMu.Unlock()
|
||||
|
||||
var errors []error
|
||||
for i, cb := range callbacks {
|
||||
logger.Debug("Executing shutdown callback %d/%d", i+1, len(callbacks))
|
||||
if err := cb(ctx); err != nil {
|
||||
logger.Error("Shutdown callback %d failed: %v", i+1, err)
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("shutdown callbacks failed: %v", errors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShutdownWithCallbacks performs shutdown and executes all registered callbacks
|
||||
func (gs *GracefulServer) ShutdownWithCallbacks(ctx context.Context) error {
|
||||
// Execute callbacks first
|
||||
if err := executeShutdownCallbacks(ctx); err != nil {
|
||||
logger.Error("Error executing shutdown callbacks: %v", err)
|
||||
}
|
||||
|
||||
// Then shutdown the server
|
||||
return gs.Shutdown(ctx)
|
||||
}
|
||||
@@ -1,231 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGracefulServerTrackRequests(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
})
|
||||
|
||||
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
||||
|
||||
// Start some requests
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait a bit for requests to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Check in-flight count
|
||||
inFlight := srv.InFlightRequests()
|
||||
if inFlight == 0 {
|
||||
t.Error("Should have in-flight requests")
|
||||
}
|
||||
|
||||
// Wait for all requests to complete
|
||||
wg.Wait()
|
||||
|
||||
// Check that counter is back to zero
|
||||
inFlight = srv.InFlightRequests()
|
||||
if inFlight != 0 {
|
||||
t.Errorf("In-flight requests should be 0, got %d", inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
})
|
||||
|
||||
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
||||
|
||||
// Mark as shutting down
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
// Try to make a request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// Should get 503
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckHandler(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
})
|
||||
|
||||
handler := srv.HealthCheckHandler()
|
||||
|
||||
// Healthy
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if w.Body.String() != `{"status":"healthy"}` {
|
||||
t.Errorf("Unexpected body: %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
// Shutting down
|
||||
t.Run("ShuttingDown", func(t *testing.T) {
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadinessHandler(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
})
|
||||
|
||||
handler := srv.ReadinessHandler()
|
||||
|
||||
// Ready with no in-flight requests
|
||||
t.Run("Ready", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if body != `{"ready":true,"in_flight_requests":0}` {
|
||||
t.Errorf("Unexpected body: %s", body)
|
||||
}
|
||||
})
|
||||
|
||||
// Not ready during shutdown
|
||||
t.Run("NotReady", func(t *testing.T) {
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
req := httptest.NewRequest("GET", "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShutdownCallbacks(t *testing.T) {
|
||||
callbackExecuted := false
|
||||
|
||||
RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
callbackExecuted = true
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := executeShutdownCallbacks(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("executeShutdownCallbacks() error = %v", err)
|
||||
}
|
||||
|
||||
if !callbackExecuted {
|
||||
t.Error("Shutdown callback was not executed")
|
||||
}
|
||||
|
||||
// Reset for other tests
|
||||
shutdownCallbacks = nil
|
||||
}
|
||||
|
||||
func TestDrainRequests(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
DrainTimeout: 1 * time.Second,
|
||||
})
|
||||
|
||||
// Simulate in-flight requests
|
||||
srv.inFlightRequests.Add(3)
|
||||
|
||||
// Start draining in background
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Simulate requests completing
|
||||
srv.inFlightRequests.Add(-3)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := srv.drainRequests(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("drainRequests() error = %v", err)
|
||||
}
|
||||
|
||||
if srv.InFlightRequests() != 0 {
|
||||
t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrainRequestsTimeout(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
DrainTimeout: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Simulate in-flight requests that don't complete
|
||||
srv.inFlightRequests.Add(5)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := srv.drainRequests(ctx)
|
||||
if err == nil {
|
||||
t.Error("drainRequests() should timeout with error")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
srv.inFlightRequests.Add(-5)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
// This test is in ratelimit_test.go since getClientIP is used by rate limiter
|
||||
// Including here for completeness of server tests
|
||||
}
|
||||
190
pkg/server/tls.go
Normal file
190
pkg/server/tls.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// generateSelfSignedCert generates a self-signed certificate for the given host.
|
||||
// Returns the certificate and private key in PEM format.
|
||||
func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) {
|
||||
// Generate private key
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
// Create certificate template
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
|
||||
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"ResolveSpec Self-Signed"},
|
||||
CommonName: host,
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
// Add host as DNS name or IP address
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
template.IPAddresses = []net.IP{ip}
|
||||
} else {
|
||||
template.DNSNames = []string{host}
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
// Encode certificate to PEM
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
|
||||
// Encode private key to PEM
|
||||
privBytes, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
|
||||
}
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
|
||||
|
||||
return certPEM, keyPEM, nil
|
||||
}
|
||||
|
||||
// saveCertToTempFiles saves certificate and key PEM data to temporary files.
|
||||
// Returns the file paths for the certificate and key.
|
||||
func saveCertToTempFiles(certPEM, keyPEM []byte) (certFile, keyFile string, err error) {
|
||||
// Create temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "resolvespec-certs-*")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
certFile = filepath.Join(tmpDir, "cert.pem")
|
||||
keyFile = filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
// Write certificate
|
||||
if err := os.WriteFile(certFile, certPEM, 0600); err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
return "", "", fmt.Errorf("failed to write certificate: %w", err)
|
||||
}
|
||||
|
||||
// Write key
|
||||
if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
return "", "", fmt.Errorf("failed to write private key: %w", err)
|
||||
}
|
||||
|
||||
return certFile, keyFile, nil
|
||||
}
|
||||
|
||||
// setupAutoTLS configures automatic TLS certificate management using Let's Encrypt.
|
||||
// Returns a TLS config that can be used with http.Server.
|
||||
func setupAutoTLS(domains []string, email, cacheDir string) (*tls.Config, error) {
|
||||
if len(domains) == 0 {
|
||||
return nil, fmt.Errorf("at least one domain must be specified for AutoTLS")
|
||||
}
|
||||
|
||||
// Set default cache directory
|
||||
if cacheDir == "" {
|
||||
cacheDir = "./certs-cache"
|
||||
}
|
||||
|
||||
// Create cache directory if it doesn't exist
|
||||
if err := os.MkdirAll(cacheDir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create certificate cache directory: %w", err)
|
||||
}
|
||||
|
||||
// Create autocert manager
|
||||
m := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(cacheDir),
|
||||
HostPolicy: autocert.HostWhitelist(domains...),
|
||||
Email: email,
|
||||
}
|
||||
|
||||
// Create TLS config
|
||||
tlsConfig := m.TLSConfig()
|
||||
tlsConfig.MinVersion = tls.VersionTLS12
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// configureTLS configures TLS for the server based on the provided configuration.
|
||||
// Returns the TLS config and certificate/key file paths (if applicable).
|
||||
func configureTLS(cfg Config) (*tls.Config, string, string, error) {
|
||||
// Option 1: Certificate files provided
|
||||
if cfg.SSLCert != "" && cfg.SSLKey != "" {
|
||||
// Validate that files exist
|
||||
if _, err := os.Stat(cfg.SSLCert); os.IsNotExist(err) {
|
||||
return nil, "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert)
|
||||
}
|
||||
if _, err := os.Stat(cfg.SSLKey); os.IsNotExist(err) {
|
||||
return nil, "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey)
|
||||
}
|
||||
|
||||
// Return basic TLS config - cert/key will be loaded by ListenAndServeTLS
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
return tlsConfig, cfg.SSLCert, cfg.SSLKey, nil
|
||||
}
|
||||
|
||||
// Option 2: Auto TLS (Let's Encrypt)
|
||||
if cfg.AutoTLS {
|
||||
tlsConfig, err := setupAutoTLS(cfg.AutoTLSDomains, cfg.AutoTLSEmail, cfg.AutoTLSCacheDir)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("failed to setup AutoTLS: %w", err)
|
||||
}
|
||||
return tlsConfig, "", "", nil
|
||||
}
|
||||
|
||||
// Option 3: Self-signed certificate
|
||||
if cfg.SelfSignedSSL {
|
||||
host := cfg.Host
|
||||
if host == "" || host == "0.0.0.0" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
certPEM, keyPEM, err := generateSelfSignedCert(host)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err)
|
||||
}
|
||||
|
||||
certFile, keyFile, err := saveCertToTempFiles(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("failed to save self-signed certificate: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
return tlsConfig, certFile, keyFile, nil
|
||||
}
|
||||
|
||||
return nil, "", "", nil
|
||||
}
|
||||
Reference in New Issue
Block a user