mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
Middleware enhancements
This commit is contained in:
parent
b741958895
commit
2a84652dba
@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/uptrace/bunrouter"
|
"github.com/uptrace/bunrouter"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BunRouterAdapter adapts uptrace/bunrouter to work with our Router interface
|
// BunRouterAdapter adapts uptrace/bunrouter to work with our Router interface
|
||||||
@ -36,7 +37,10 @@ func (b *BunRouterAdapter) ServeHTTP(w common.ResponseWriter, r common.Request)
|
|||||||
// This method would be used when we need to serve through our interface
|
// This method would be used when we need to serve through our interface
|
||||||
// For now, we'll work directly with the underlying router
|
// For now, we'll work directly with the underlying router
|
||||||
w.WriteHeader(http.StatusNotImplemented)
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetBunRouter() for direct access"}`))
|
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetBunRouter() for direct access"}`))
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to write. %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetBunRouter returns the underlying bunrouter for direct access
|
// GetBunRouter returns the underlying bunrouter for direct access
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MuxAdapter adapts Gorilla Mux to work with our Router interface
|
// MuxAdapter adapts Gorilla Mux to work with our Router interface
|
||||||
@ -33,7 +34,10 @@ func (m *MuxAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
|
|||||||
// This method would be used when we need to serve through our interface
|
// This method would be used when we need to serve through our interface
|
||||||
// For now, we'll work directly with the underlying router
|
// For now, we'll work directly with the underlying router
|
||||||
w.WriteHeader(http.StatusNotImplemented)
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetMuxRouter() for direct access"}`))
|
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetMuxRouter() for direct access"}`))
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to write. %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MuxRouteRegistration implements RouteRegistration for Mux
|
// MuxRouteRegistration implements RouteRegistration for Mux
|
||||||
|
|||||||
@ -45,4 +45,3 @@ func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, e
|
|||||||
OriginalType: originalType,
|
OriginalType: originalType,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,8 @@ package metrics
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Provider defines the interface for metric collection
|
// Provider defines the interface for metric collection
|
||||||
@ -57,12 +59,15 @@ func (n *NoOpProvider) IncRequestsInFlight()
|
|||||||
func (n *NoOpProvider) DecRequestsInFlight() {}
|
func (n *NoOpProvider) DecRequestsInFlight() {}
|
||||||
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||||
}
|
}
|
||||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||||
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
||||||
func (n *NoOpProvider) Handler() http.Handler {
|
func (n *NoOpProvider) Handler() http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
w.Write([]byte("Metrics provider not configured"))
|
_, err := w.Write([]byte("Metrics provider not configured"))
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to write. %v", err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,14 +12,14 @@ import (
|
|||||||
|
|
||||||
// PrometheusProvider implements the Provider interface using Prometheus
|
// PrometheusProvider implements the Provider interface using Prometheus
|
||||||
type PrometheusProvider struct {
|
type PrometheusProvider struct {
|
||||||
requestDuration *prometheus.HistogramVec
|
requestDuration *prometheus.HistogramVec
|
||||||
requestTotal *prometheus.CounterVec
|
requestTotal *prometheus.CounterVec
|
||||||
requestsInFlight prometheus.Gauge
|
requestsInFlight prometheus.Gauge
|
||||||
dbQueryDuration *prometheus.HistogramVec
|
dbQueryDuration *prometheus.HistogramVec
|
||||||
dbQueryTotal *prometheus.CounterVec
|
dbQueryTotal *prometheus.CounterVec
|
||||||
cacheHits *prometheus.CounterVec
|
cacheHits *prometheus.CounterVec
|
||||||
cacheMisses *prometheus.CounterVec
|
cacheMisses *prometheus.CounterVec
|
||||||
cacheSize *prometheus.GaugeVec
|
cacheSize *prometheus.GaugeVec
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPrometheusProvider creates a new Prometheus metrics provider
|
// NewPrometheusProvider creates a new Prometheus metrics provider
|
||||||
|
|||||||
@ -1,6 +1,14 @@
|
|||||||
# Middleware Package
|
# Middleware Package
|
||||||
|
|
||||||
HTTP middleware utilities including rate limiting.
|
HTTP middleware utilities for security and performance.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
1. [Rate Limiting](#rate-limiting)
|
||||||
|
2. [Request Size Limits](#request-size-limits)
|
||||||
|
3. [Input Sanitization](#input-sanitization)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Rate Limiting
|
## Rate Limiting
|
||||||
|
|
||||||
@ -370,3 +378,429 @@ func healthHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Request Size Limits
|
||||||
|
|
||||||
|
Protect against oversized request bodies with configurable size limits.
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
|
||||||
|
// Default: 10MB limit
|
||||||
|
sizeLimiter := middleware.NewRequestSizeLimiter(0)
|
||||||
|
router.Use(sizeLimiter.Middleware)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Size Limit
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 5MB limit
|
||||||
|
sizeLimiter := middleware.NewRequestSizeLimiter(5 * 1024 * 1024)
|
||||||
|
router.Use(sizeLimiter.Middleware)
|
||||||
|
|
||||||
|
// Or use constants
|
||||||
|
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size5MB)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Available Size Constants
|
||||||
|
|
||||||
|
```go
|
||||||
|
middleware.Size1MB // 1 MB
|
||||||
|
middleware.Size5MB // 5 MB
|
||||||
|
middleware.Size10MB // 10 MB (default)
|
||||||
|
middleware.Size50MB // 50 MB
|
||||||
|
middleware.Size100MB // 100 MB
|
||||||
|
```
|
||||||
|
|
||||||
|
### Different Limits Per Route
|
||||||
|
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// File upload endpoint: 50MB
|
||||||
|
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
|
||||||
|
uploadRouter := router.PathPrefix("/upload").Subrouter()
|
||||||
|
uploadRouter.Use(uploadLimiter.Middleware)
|
||||||
|
|
||||||
|
// API endpoints: 1MB
|
||||||
|
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
|
||||||
|
apiRouter := router.PathPrefix("/api").Subrouter()
|
||||||
|
apiRouter.Use(apiLimiter.Middleware)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Dynamic Size Limits
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Custom size based on request
|
||||||
|
sizeFunc := func(r *http.Request) int64 {
|
||||||
|
// Premium users get 50MB
|
||||||
|
if isPremiumUser(r) {
|
||||||
|
return middleware.Size50MB
|
||||||
|
}
|
||||||
|
// Free users get 5MB
|
||||||
|
return middleware.Size5MB
|
||||||
|
}
|
||||||
|
|
||||||
|
router.Use(sizeLimiter.MiddlewareWithCustomSize(sizeFunc))
|
||||||
|
```
|
||||||
|
|
||||||
|
**By Content-Type:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
sizeFunc := func(r *http.Request) int64 {
|
||||||
|
contentType := r.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.Contains(contentType, "multipart/form-data"):
|
||||||
|
return middleware.Size50MB // File uploads
|
||||||
|
case strings.Contains(contentType, "application/json"):
|
||||||
|
return middleware.Size1MB // JSON APIs
|
||||||
|
default:
|
||||||
|
return middleware.Size10MB // Default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Error Response
|
||||||
|
|
||||||
|
When size limit exceeded:
|
||||||
|
|
||||||
|
```http
|
||||||
|
HTTP/1.1 413 Request Entity Too Large
|
||||||
|
X-Max-Request-Size: 10485760
|
||||||
|
|
||||||
|
http: request body too large
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complete Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// API routes: 1MB limit
|
||||||
|
api := router.PathPrefix("/api").Subrouter()
|
||||||
|
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
|
||||||
|
api.Use(apiLimiter.Middleware)
|
||||||
|
api.HandleFunc("/users", createUserHandler).Methods("POST")
|
||||||
|
|
||||||
|
// Upload routes: 50MB limit
|
||||||
|
upload := router.PathPrefix("/upload").Subrouter()
|
||||||
|
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
|
||||||
|
upload.Use(uploadLimiter.Middleware)
|
||||||
|
upload.HandleFunc("/file", uploadFileHandler).Methods("POST")
|
||||||
|
|
||||||
|
log.Fatal(http.ListenAndServe(":8080", router))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Input Sanitization
|
||||||
|
|
||||||
|
Protect against XSS, injection attacks, and malicious input.
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
|
||||||
|
// Default sanitizer (safe defaults)
|
||||||
|
sanitizer := middleware.DefaultSanitizer()
|
||||||
|
router.Use(sanitizer.Middleware)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sanitizer Types
|
||||||
|
|
||||||
|
**Default Sanitizer (Recommended):**
|
||||||
|
|
||||||
|
```go
|
||||||
|
sanitizer := middleware.DefaultSanitizer()
|
||||||
|
// ✓ Escapes HTML entities
|
||||||
|
// ✓ Removes null bytes
|
||||||
|
// ✓ Removes control characters
|
||||||
|
// ✓ Blocks XSS patterns (script tags, event handlers)
|
||||||
|
// ✗ Does not strip HTML (allows legitimate content)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Strict Sanitizer:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
sanitizer := middleware.StrictSanitizer()
|
||||||
|
// ✓ All default features
|
||||||
|
// ✓ Strips ALL HTML tags
|
||||||
|
// ✓ Max string length: 10,000 chars
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Configuration
|
||||||
|
|
||||||
|
```go
|
||||||
|
sanitizer := &middleware.Sanitizer{
|
||||||
|
StripHTML: true, // Remove HTML tags
|
||||||
|
EscapeHTML: false, // Don't escape (already stripped)
|
||||||
|
RemoveNullBytes: true, // Remove \x00
|
||||||
|
RemoveControlChars: true, // Remove dangerous control chars
|
||||||
|
MaxStringLength: 5000, // Limit to 5000 chars
|
||||||
|
|
||||||
|
// Block patterns (regex)
|
||||||
|
BlockPatterns: []*regexp.Regexp{
|
||||||
|
regexp.MustCompile(`(?i)<script`),
|
||||||
|
regexp.MustCompile(`(?i)javascript:`),
|
||||||
|
},
|
||||||
|
|
||||||
|
// Custom sanitization function
|
||||||
|
CustomSanitizer: func(s string) string {
|
||||||
|
// Your custom logic
|
||||||
|
return strings.ToLower(s)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router.Use(sanitizer.Middleware)
|
||||||
|
```
|
||||||
|
|
||||||
|
### What Gets Sanitized
|
||||||
|
|
||||||
|
**Automatic (via middleware):**
|
||||||
|
- Query parameters
|
||||||
|
- Headers (User-Agent, Referer, X-Forwarded-For, X-Real-IP)
|
||||||
|
|
||||||
|
**Manual (in your handler):**
|
||||||
|
- Request body (JSON, form data)
|
||||||
|
- Database queries
|
||||||
|
- File names
|
||||||
|
|
||||||
|
### Manual Sanitization
|
||||||
|
|
||||||
|
**String Values:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
sanitizer := middleware.DefaultSanitizer()
|
||||||
|
|
||||||
|
// Sanitize user input
|
||||||
|
username := sanitizer.Sanitize(r.FormValue("username"))
|
||||||
|
email := sanitizer.Sanitize(r.FormValue("email"))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Map/JSON Data:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
var data map[string]interface{}
|
||||||
|
json.Unmarshal(body, &data)
|
||||||
|
|
||||||
|
// Sanitize all string values recursively
|
||||||
|
sanitizedData := sanitizer.SanitizeMap(data)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Nested Structures:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
Bio string
|
||||||
|
Profile map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// After unmarshaling
|
||||||
|
user.Name = sanitizer.Sanitize(user.Name)
|
||||||
|
user.Email = sanitizer.Sanitize(user.Email)
|
||||||
|
user.Bio = sanitizer.Sanitize(user.Bio)
|
||||||
|
user.Profile = sanitizer.SanitizeMap(user.Profile)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Specialized Sanitizers
|
||||||
|
|
||||||
|
**Filenames:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
|
||||||
|
filename := middleware.SanitizeFilename(uploadedFilename)
|
||||||
|
// Removes: .., /, \, null bytes
|
||||||
|
// Limits: 255 characters
|
||||||
|
```
|
||||||
|
|
||||||
|
**Emails:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
email := middleware.SanitizeEmail(" USER@EXAMPLE.COM ")
|
||||||
|
// Result: "user@example.com"
|
||||||
|
// Trims, lowercases, removes null bytes
|
||||||
|
```
|
||||||
|
|
||||||
|
**URLs:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
url := middleware.SanitizeURL(userInput)
|
||||||
|
// Blocks: javascript:, data: protocols
|
||||||
|
// Removes: null bytes
|
||||||
|
```
|
||||||
|
|
||||||
|
### Blocked Patterns (Default)
|
||||||
|
|
||||||
|
The default sanitizer blocks:
|
||||||
|
|
||||||
|
1. **Script tags**: `<script>...</script>`
|
||||||
|
2. **JavaScript protocol**: `javascript:alert(1)`
|
||||||
|
3. **Event handlers**: `onclick="..."`, `onerror="..."`
|
||||||
|
4. **Iframes**: `<iframe src="...">`
|
||||||
|
5. **Objects**: `<object data="...">`
|
||||||
|
6. **Embeds**: `<embed src="...">`
|
||||||
|
|
||||||
|
### Security Best Practices
|
||||||
|
|
||||||
|
**1. Layer Defense:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Layer 1: Middleware (query params, headers)
|
||||||
|
router.Use(sanitizer.Middleware)
|
||||||
|
|
||||||
|
// Layer 2: Input validation (in handler)
|
||||||
|
func createUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var user User
|
||||||
|
json.NewDecoder(r.Body).Decode(&user)
|
||||||
|
|
||||||
|
// Sanitize
|
||||||
|
user.Name = sanitizer.Sanitize(user.Name)
|
||||||
|
user.Email = middleware.SanitizeEmail(user.Email)
|
||||||
|
|
||||||
|
// Validate
|
||||||
|
if !isValidEmail(user.Email) {
|
||||||
|
http.Error(w, "Invalid email", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use parameterized queries (prevents SQL injection)
|
||||||
|
db.Exec("INSERT INTO users (name, email) VALUES (?, ?)",
|
||||||
|
user.Name, user.Email)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Context-Aware Sanitization:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// HTML content (user posts, comments)
|
||||||
|
sanitizer := middleware.StrictSanitizer()
|
||||||
|
post.Content = sanitizer.Sanitize(post.Content)
|
||||||
|
|
||||||
|
// Structured data (JSON API)
|
||||||
|
sanitizer := middleware.DefaultSanitizer()
|
||||||
|
data = sanitizer.SanitizeMap(jsonData)
|
||||||
|
|
||||||
|
// Search queries (preserve special chars)
|
||||||
|
query = middleware.SanitizeFilename(searchTerm) // Light sanitization
|
||||||
|
```
|
||||||
|
|
||||||
|
**3. Output Encoding:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// When rendering HTML
|
||||||
|
import "html/template"
|
||||||
|
|
||||||
|
tmpl := template.Must(template.New("page").Parse(`
|
||||||
|
<h1>{{.Title}}</h1>
|
||||||
|
<p>{{.Content}}</p>
|
||||||
|
`))
|
||||||
|
|
||||||
|
// template.HTML automatically escapes
|
||||||
|
tmpl.Execute(w, data)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complete Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Apply sanitization middleware
|
||||||
|
sanitizer := middleware.DefaultSanitizer()
|
||||||
|
router.Use(sanitizer.Middleware)
|
||||||
|
|
||||||
|
router.HandleFunc("/api/users", createUserHandler).Methods("POST")
|
||||||
|
|
||||||
|
log.Fatal(http.ListenAndServe(":8080", router))
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
sanitizer := middleware.DefaultSanitizer()
|
||||||
|
|
||||||
|
var user struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Bio string `json:"bio"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
|
||||||
|
http.Error(w, "Invalid JSON", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize inputs
|
||||||
|
user.Name = sanitizer.Sanitize(user.Name)
|
||||||
|
user.Email = middleware.SanitizeEmail(user.Email)
|
||||||
|
user.Bio = sanitizer.Sanitize(user.Bio)
|
||||||
|
|
||||||
|
// Validate
|
||||||
|
if len(user.Name) == 0 || len(user.Email) == 0 {
|
||||||
|
http.Error(w, "Name and email required", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save to database (use parameterized queries!)
|
||||||
|
// db.Exec("INSERT INTO users (name, email, bio) VALUES (?, ?, ?)",
|
||||||
|
// user.Name, user.Email, user.Bio)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"status": "created",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing Sanitization
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Test XSS prevention
|
||||||
|
curl -X POST http://localhost:8080/api/users \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"name": "<script>alert(1)</script>John",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"bio": "My bio with <iframe src=\"evil.com\"></iframe>"
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Script tags and iframes should be removed
|
||||||
|
```
|
||||||
|
|
||||||
|
### Performance
|
||||||
|
|
||||||
|
- **Overhead**: <1ms per request for typical payloads
|
||||||
|
- **Regex compilation**: Done once at initialization
|
||||||
|
- **Safe for production**: Minimal performance impact
|
||||||
|
|||||||
251
pkg/middleware/sanitize.go
Normal file
251
pkg/middleware/sanitize.go
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"html"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sanitizer provides input sanitization beyond SQL injection protection
|
||||||
|
type Sanitizer struct {
|
||||||
|
// StripHTML removes HTML tags from input
|
||||||
|
StripHTML bool
|
||||||
|
|
||||||
|
// EscapeHTML escapes HTML entities
|
||||||
|
EscapeHTML bool
|
||||||
|
|
||||||
|
// RemoveNullBytes removes null bytes from input
|
||||||
|
RemoveNullBytes bool
|
||||||
|
|
||||||
|
// RemoveControlChars removes control characters (except newline, carriage return, tab)
|
||||||
|
RemoveControlChars bool
|
||||||
|
|
||||||
|
// MaxStringLength limits individual string field length (0 = no limit)
|
||||||
|
MaxStringLength int
|
||||||
|
|
||||||
|
// BlockPatterns are regex patterns to block (e.g., script tags, SQL keywords)
|
||||||
|
BlockPatterns []*regexp.Regexp
|
||||||
|
|
||||||
|
// Custom sanitization function
|
||||||
|
CustomSanitizer func(string) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultSanitizer returns a sanitizer with secure defaults
|
||||||
|
func DefaultSanitizer() *Sanitizer {
|
||||||
|
return &Sanitizer{
|
||||||
|
StripHTML: false, // Don't strip by default (breaks legitimate HTML content)
|
||||||
|
EscapeHTML: true, // Escape HTML entities to prevent XSS
|
||||||
|
RemoveNullBytes: true, // Remove null bytes (security best practice)
|
||||||
|
RemoveControlChars: true, // Remove dangerous control characters
|
||||||
|
MaxStringLength: 0, // No limit by default
|
||||||
|
|
||||||
|
// Block common XSS and injection patterns
|
||||||
|
BlockPatterns: []*regexp.Regexp{
|
||||||
|
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`), // Script tags
|
||||||
|
regexp.MustCompile(`(?i)javascript:`), // JavaScript protocol
|
||||||
|
regexp.MustCompile(`(?i)on\w+\s*=`), // Event handlers (onclick, onerror, etc.)
|
||||||
|
regexp.MustCompile(`(?i)<iframe[^>]*>`), // Iframes
|
||||||
|
regexp.MustCompile(`(?i)<object[^>]*>`), // Objects
|
||||||
|
regexp.MustCompile(`(?i)<embed[^>]*>`), // Embeds
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StrictSanitizer returns a sanitizer with very strict rules
|
||||||
|
func StrictSanitizer() *Sanitizer {
|
||||||
|
s := DefaultSanitizer()
|
||||||
|
s.StripHTML = true
|
||||||
|
s.MaxStringLength = 10000
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize sanitizes a string value
|
||||||
|
func (s *Sanitizer) Sanitize(value string) string {
|
||||||
|
if value == "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove null bytes
|
||||||
|
if s.RemoveNullBytes {
|
||||||
|
value = strings.ReplaceAll(value, "\x00", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove control characters
|
||||||
|
if s.RemoveControlChars {
|
||||||
|
value = removeControlCharacters(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check block patterns
|
||||||
|
for _, pattern := range s.BlockPatterns {
|
||||||
|
if pattern.MatchString(value) {
|
||||||
|
// Replace matched pattern with empty string
|
||||||
|
value = pattern.ReplaceAllString(value, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip HTML tags
|
||||||
|
if s.StripHTML {
|
||||||
|
value = stripHTMLTags(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Escape HTML entities
|
||||||
|
if s.EscapeHTML && !s.StripHTML {
|
||||||
|
value = html.EscapeString(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply max length
|
||||||
|
if s.MaxStringLength > 0 && len(value) > s.MaxStringLength {
|
||||||
|
value = value[:s.MaxStringLength]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply custom sanitizer
|
||||||
|
if s.CustomSanitizer != nil {
|
||||||
|
value = s.CustomSanitizer(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeMap sanitizes all string values in a map
|
||||||
|
func (s *Sanitizer) SanitizeMap(data map[string]interface{}) map[string]interface{} {
|
||||||
|
result := make(map[string]interface{})
|
||||||
|
for key, value := range data {
|
||||||
|
result[key] = s.sanitizeValue(value)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeValue recursively sanitizes values
|
||||||
|
func (s *Sanitizer) sanitizeValue(value interface{}) interface{} {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
return s.Sanitize(v)
|
||||||
|
case map[string]interface{}:
|
||||||
|
return s.SanitizeMap(v)
|
||||||
|
case []interface{}:
|
||||||
|
result := make([]interface{}, len(v))
|
||||||
|
for i, item := range v {
|
||||||
|
result[i] = s.sanitizeValue(item)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
default:
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns an HTTP middleware that sanitizes request headers and query params
|
||||||
|
// Note: Body sanitization should be done at the application level after parsing
|
||||||
|
func (s *Sanitizer) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Sanitize query parameters
|
||||||
|
if r.URL.RawQuery != "" {
|
||||||
|
q := r.URL.Query()
|
||||||
|
sanitized := false
|
||||||
|
for key, values := range q {
|
||||||
|
for i, value := range values {
|
||||||
|
sanitizedValue := s.Sanitize(value)
|
||||||
|
if sanitizedValue != value {
|
||||||
|
values[i] = sanitizedValue
|
||||||
|
sanitized = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sanitized {
|
||||||
|
q[key] = values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sanitized {
|
||||||
|
r.URL.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize specific headers (User-Agent, Referer, etc.)
|
||||||
|
dangerousHeaders := []string{
|
||||||
|
"User-Agent",
|
||||||
|
"Referer",
|
||||||
|
"X-Forwarded-For",
|
||||||
|
"X-Real-IP",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, header := range dangerousHeaders {
|
||||||
|
if value := r.Header.Get(header); value != "" {
|
||||||
|
sanitized := s.Sanitize(value)
|
||||||
|
if sanitized != value {
|
||||||
|
r.Header.Set(header, sanitized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
// removeControlCharacters removes control characters except \n, \r, \t
|
||||||
|
func removeControlCharacters(s string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
for _, r := range s {
|
||||||
|
// Keep newline, carriage return, tab, and non-control characters
|
||||||
|
if r == '\n' || r == '\r' || r == '\t' || r >= 32 {
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripHTMLTags removes HTML tags from a string
|
||||||
|
func stripHTMLTags(s string) string {
|
||||||
|
// Simple regex to remove HTML tags
|
||||||
|
re := regexp.MustCompile(`<[^>]*>`)
|
||||||
|
return re.ReplaceAllString(s, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common sanitization patterns
|
||||||
|
|
||||||
|
// SanitizeFilename sanitizes a filename
|
||||||
|
func SanitizeFilename(filename string) string {
|
||||||
|
// Remove path traversal attempts
|
||||||
|
filename = strings.ReplaceAll(filename, "..", "")
|
||||||
|
filename = strings.ReplaceAll(filename, "/", "")
|
||||||
|
filename = strings.ReplaceAll(filename, "\\", "")
|
||||||
|
|
||||||
|
// Remove null bytes
|
||||||
|
filename = strings.ReplaceAll(filename, "\x00", "")
|
||||||
|
|
||||||
|
// Limit length
|
||||||
|
if len(filename) > 255 {
|
||||||
|
filename = filename[:255]
|
||||||
|
}
|
||||||
|
|
||||||
|
return filename
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeEmail performs basic email sanitization
|
||||||
|
func SanitizeEmail(email string) string {
|
||||||
|
email = strings.TrimSpace(strings.ToLower(email))
|
||||||
|
|
||||||
|
// Remove dangerous characters
|
||||||
|
email = strings.ReplaceAll(email, "\x00", "")
|
||||||
|
email = removeControlCharacters(email)
|
||||||
|
|
||||||
|
return email
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeURL performs basic URL sanitization
|
||||||
|
func SanitizeURL(url string) string {
|
||||||
|
url = strings.TrimSpace(url)
|
||||||
|
|
||||||
|
// Remove null bytes
|
||||||
|
url = strings.ReplaceAll(url, "\x00", "")
|
||||||
|
|
||||||
|
// Block javascript: and data: protocols
|
||||||
|
if strings.HasPrefix(strings.ToLower(url), "javascript:") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(strings.ToLower(url), "data:") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return url
|
||||||
|
}
|
||||||
70
pkg/middleware/sizelimit.go
Normal file
70
pkg/middleware/sizelimit.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultMaxRequestSize is the default maximum request body size (10MB)
|
||||||
|
DefaultMaxRequestSize = 10 * 1024 * 1024 // 10MB
|
||||||
|
|
||||||
|
// MaxRequestSizeHeader is the header name for max request size
|
||||||
|
MaxRequestSizeHeader = "X-Max-Request-Size"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestSizeLimiter limits the size of request bodies
|
||||||
|
type RequestSizeLimiter struct {
|
||||||
|
maxSize int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequestSizeLimiter creates a new request size limiter
|
||||||
|
// maxSize is in bytes. If 0, uses DefaultMaxRequestSize (10MB)
|
||||||
|
func NewRequestSizeLimiter(maxSize int64) *RequestSizeLimiter {
|
||||||
|
if maxSize <= 0 {
|
||||||
|
maxSize = DefaultMaxRequestSize
|
||||||
|
}
|
||||||
|
return &RequestSizeLimiter{
|
||||||
|
maxSize: maxSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns an HTTP middleware that enforces request size limits
|
||||||
|
func (rsl *RequestSizeLimiter) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Set max bytes reader on the request body
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, rsl.maxSize)
|
||||||
|
|
||||||
|
// Add informational header
|
||||||
|
w.Header().Set(MaxRequestSizeHeader, fmt.Sprintf("%d", rsl.maxSize))
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// MiddlewareWithCustomSize returns middleware with a custom size limit function
|
||||||
|
// This allows different size limits based on the request
|
||||||
|
func (rsl *RequestSizeLimiter) MiddlewareWithCustomSize(sizeFunc func(*http.Request) int64) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
maxSize := sizeFunc(r)
|
||||||
|
if maxSize <= 0 {
|
||||||
|
maxSize = rsl.maxSize
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, maxSize)
|
||||||
|
w.Header().Set(MaxRequestSizeHeader, fmt.Sprintf("%d", maxSize))
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common size limits
|
||||||
|
const (
|
||||||
|
Size1MB = 1 * 1024 * 1024
|
||||||
|
Size5MB = 5 * 1024 * 1024
|
||||||
|
Size10MB = 10 * 1024 * 1024
|
||||||
|
Size50MB = 50 * 1024 * 1024
|
||||||
|
Size100MB = 100 * 1024 * 1024
|
||||||
|
)
|
||||||
493
pkg/server/README.md
Normal file
493
pkg/server/README.md
Normal file
@ -0,0 +1,493 @@
|
|||||||
|
# Server Package
|
||||||
|
|
||||||
|
Graceful HTTP server with request draining and shutdown coordination.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/server"
|
||||||
|
|
||||||
|
// Create server
|
||||||
|
srv := server.NewGracefulServer(server.Config{
|
||||||
|
Addr: ":8080",
|
||||||
|
Handler: router,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start server (blocks until shutdown signal)
|
||||||
|
if err := srv.ListenAndServe(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
✅ Graceful shutdown on SIGINT/SIGTERM
|
||||||
|
✅ Request draining (waits for in-flight requests)
|
||||||
|
✅ Automatic request rejection during shutdown
|
||||||
|
✅ Health and readiness endpoints
|
||||||
|
✅ Shutdown callbacks for cleanup
|
||||||
|
✅ Configurable timeouts
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
```go
|
||||||
|
config := server.Config{
|
||||||
|
// Server address
|
||||||
|
Addr: ":8080",
|
||||||
|
|
||||||
|
// HTTP handler
|
||||||
|
Handler: myRouter,
|
||||||
|
|
||||||
|
// Maximum time for graceful shutdown (default: 30s)
|
||||||
|
ShutdownTimeout: 30 * time.Second,
|
||||||
|
|
||||||
|
// Time to wait for in-flight requests (default: 25s)
|
||||||
|
DrainTimeout: 25 * time.Second,
|
||||||
|
|
||||||
|
// Request read timeout (default: 10s)
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
|
||||||
|
// Response write timeout (default: 10s)
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
|
||||||
|
// Idle connection timeout (default: 120s)
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := server.NewGracefulServer(config)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Shutdown Behavior
|
||||||
|
|
||||||
|
**Signal received (SIGINT/SIGTERM):**
|
||||||
|
|
||||||
|
1. **Mark as shutting down** - New requests get 503
|
||||||
|
2. **Drain requests** - Wait up to `DrainTimeout` for in-flight requests
|
||||||
|
3. **Shutdown server** - Close listeners and connections
|
||||||
|
4. **Execute callbacks** - Run registered cleanup functions
|
||||||
|
|
||||||
|
```
|
||||||
|
Time Event
|
||||||
|
─────────────────────────────────────────
|
||||||
|
0s Signal received: SIGTERM
|
||||||
|
├─ Mark as shutting down
|
||||||
|
├─ Reject new requests (503)
|
||||||
|
└─ Start draining...
|
||||||
|
|
||||||
|
1s In-flight: 50 requests
|
||||||
|
2s In-flight: 32 requests
|
||||||
|
3s In-flight: 12 requests
|
||||||
|
4s In-flight: 3 requests
|
||||||
|
5s In-flight: 0 requests ✓
|
||||||
|
└─ All requests drained
|
||||||
|
|
||||||
|
5s Execute shutdown callbacks
|
||||||
|
6s Shutdown complete
|
||||||
|
```
|
||||||
|
|
||||||
|
## Health Checks
|
||||||
|
|
||||||
|
### Health Endpoint
|
||||||
|
|
||||||
|
Returns 200 when healthy, 503 when shutting down:
|
||||||
|
|
||||||
|
```go
|
||||||
|
router.HandleFunc("/health", srv.HealthCheckHandler())
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response (healthy):**
|
||||||
|
```json
|
||||||
|
{"status":"healthy"}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response (shutting down):**
|
||||||
|
```json
|
||||||
|
{"status":"shutting_down"}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Readiness Endpoint
|
||||||
|
|
||||||
|
Includes in-flight request count:
|
||||||
|
|
||||||
|
```go
|
||||||
|
router.HandleFunc("/ready", srv.ReadinessHandler())
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{"ready":true,"in_flight_requests":12}
|
||||||
|
```
|
||||||
|
|
||||||
|
**During shutdown:**
|
||||||
|
```json
|
||||||
|
{"ready":false,"reason":"shutting_down"}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Shutdown Callbacks
|
||||||
|
|
||||||
|
Register cleanup functions to run during shutdown:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Close database
|
||||||
|
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
logger.Info("Closing database connection...")
|
||||||
|
return db.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Flush metrics
|
||||||
|
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
logger.Info("Flushing metrics...")
|
||||||
|
return metricsProvider.Flush(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Close cache
|
||||||
|
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
logger.Info("Closing cache...")
|
||||||
|
return cache.Close()
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Initialize metrics
|
||||||
|
metricsProvider := metrics.NewPrometheusProvider()
|
||||||
|
metrics.SetProvider(metricsProvider)
|
||||||
|
|
||||||
|
// Create router
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Apply middleware
|
||||||
|
rateLimiter := middleware.NewRateLimiter(100, 20)
|
||||||
|
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size10MB)
|
||||||
|
sanitizer := middleware.DefaultSanitizer()
|
||||||
|
|
||||||
|
router.Use(rateLimiter.Middleware)
|
||||||
|
router.Use(sizeLimiter.Middleware)
|
||||||
|
router.Use(sanitizer.Middleware)
|
||||||
|
router.Use(metricsProvider.Middleware)
|
||||||
|
|
||||||
|
// API routes
|
||||||
|
router.HandleFunc("/api/data", dataHandler)
|
||||||
|
|
||||||
|
// Create graceful server
|
||||||
|
srv := server.NewGracefulServer(server.Config{
|
||||||
|
Addr: ":8080",
|
||||||
|
Handler: router,
|
||||||
|
ShutdownTimeout: 30 * time.Second,
|
||||||
|
DrainTimeout: 25 * time.Second,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Health checks
|
||||||
|
router.HandleFunc("/health", srv.HealthCheckHandler())
|
||||||
|
router.HandleFunc("/ready", srv.ReadinessHandler())
|
||||||
|
|
||||||
|
// Metrics endpoint
|
||||||
|
router.Handle("/metrics", metricsProvider.Handler())
|
||||||
|
|
||||||
|
// Register shutdown callbacks
|
||||||
|
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
log.Println("Cleanup: Flushing metrics...")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
log.Println("Cleanup: Closing database...")
|
||||||
|
// return db.Close()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start server (blocks until shutdown)
|
||||||
|
log.Printf("Starting server on :8080")
|
||||||
|
if err := srv.ListenAndServe(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for shutdown to complete
|
||||||
|
srv.Wait()
|
||||||
|
log.Println("Server stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func dataHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Your handler logic
|
||||||
|
time.Sleep(100 * time.Millisecond) // Simulate work
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(`{"message":"success"}`))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Kubernetes Integration
|
||||||
|
|
||||||
|
### Deployment with Probes
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
name: myapp
|
||||||
|
spec:
|
||||||
|
replicas: 3
|
||||||
|
template:
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: app
|
||||||
|
image: myapp:latest
|
||||||
|
ports:
|
||||||
|
- containerPort: 8080
|
||||||
|
|
||||||
|
# Liveness probe - is app running?
|
||||||
|
livenessProbe:
|
||||||
|
httpGet:
|
||||||
|
path: /health
|
||||||
|
port: 8080
|
||||||
|
initialDelaySeconds: 10
|
||||||
|
periodSeconds: 10
|
||||||
|
timeoutSeconds: 5
|
||||||
|
|
||||||
|
# Readiness probe - can app handle traffic?
|
||||||
|
readinessProbe:
|
||||||
|
httpGet:
|
||||||
|
path: /ready
|
||||||
|
port: 8080
|
||||||
|
initialDelaySeconds: 5
|
||||||
|
periodSeconds: 5
|
||||||
|
timeoutSeconds: 3
|
||||||
|
|
||||||
|
# Graceful shutdown
|
||||||
|
lifecycle:
|
||||||
|
preStop:
|
||||||
|
exec:
|
||||||
|
command: ["/bin/sh", "-c", "sleep 5"]
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
env:
|
||||||
|
- name: SHUTDOWN_TIMEOUT
|
||||||
|
value: "30"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Service
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: myapp
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
app: myapp
|
||||||
|
ports:
|
||||||
|
- port: 80
|
||||||
|
targetPort: 8080
|
||||||
|
type: LoadBalancer
|
||||||
|
```
|
||||||
|
|
||||||
|
## Docker Compose
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
version: '3.8'
|
||||||
|
services:
|
||||||
|
app:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8080:8080"
|
||||||
|
environment:
|
||||||
|
- SHUTDOWN_TIMEOUT=30
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
stop_grace_period: 35s # Slightly longer than shutdown timeout
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Graceful Shutdown
|
||||||
|
|
||||||
|
### Test Script
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Start server in background
|
||||||
|
./myapp &
|
||||||
|
SERVER_PID=$!
|
||||||
|
|
||||||
|
# Wait for server to start
|
||||||
|
sleep 2
|
||||||
|
|
||||||
|
# Send some requests
|
||||||
|
for i in {1..10}; do
|
||||||
|
curl http://localhost:8080/api/data &
|
||||||
|
done
|
||||||
|
|
||||||
|
# Wait a bit
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
# Send shutdown signal
|
||||||
|
kill -TERM $SERVER_PID
|
||||||
|
|
||||||
|
# Try to send more requests (should get 503)
|
||||||
|
curl -v http://localhost:8080/api/data
|
||||||
|
|
||||||
|
# Wait for server to stop
|
||||||
|
wait $SERVER_PID
|
||||||
|
echo "Server stopped gracefully"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Expected Output
|
||||||
|
|
||||||
|
```
|
||||||
|
Starting server on :8080
|
||||||
|
Received signal: terminated, initiating graceful shutdown
|
||||||
|
Starting graceful shutdown...
|
||||||
|
Waiting for 8 in-flight requests to complete...
|
||||||
|
Waiting for 4 in-flight requests to complete...
|
||||||
|
Waiting for 1 in-flight requests to complete...
|
||||||
|
All requests drained in 2.3s
|
||||||
|
Cleanup: Flushing metrics...
|
||||||
|
Cleanup: Closing database...
|
||||||
|
Shutting down HTTP server...
|
||||||
|
Graceful shutdown complete
|
||||||
|
Server stopped
|
||||||
|
```
|
||||||
|
|
||||||
|
## Monitoring In-Flight Requests
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get current in-flight count
|
||||||
|
count := srv.InFlightRequests()
|
||||||
|
fmt.Printf("In-flight requests: %d\n", count)
|
||||||
|
|
||||||
|
// Check if shutting down
|
||||||
|
if srv.IsShuttingDown() {
|
||||||
|
fmt.Println("Server is shutting down")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Custom Shutdown Logic
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Implement custom shutdown
|
||||||
|
go func() {
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
|
<-sigChan
|
||||||
|
log.Println("Shutdown signal received")
|
||||||
|
|
||||||
|
// Custom pre-shutdown logic
|
||||||
|
log.Println("Running custom cleanup...")
|
||||||
|
|
||||||
|
// Shutdown with callbacks
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := srv.ShutdownWithCallbacks(ctx); err != nil {
|
||||||
|
log.Printf("Shutdown error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
srv.server.ListenAndServe()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Servers
|
||||||
|
|
||||||
|
```go
|
||||||
|
// HTTP server
|
||||||
|
httpSrv := server.NewGracefulServer(server.Config{
|
||||||
|
Addr: ":8080",
|
||||||
|
Handler: httpRouter,
|
||||||
|
})
|
||||||
|
|
||||||
|
// HTTPS server
|
||||||
|
httpsSrv := server.NewGracefulServer(server.Config{
|
||||||
|
Addr: ":8443",
|
||||||
|
Handler: httpsRouter,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start both
|
||||||
|
go httpSrv.ListenAndServe()
|
||||||
|
go httpsSrv.ListenAndServe()
|
||||||
|
|
||||||
|
// Shutdown both on signal
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, os.Interrupt)
|
||||||
|
<-sigChan
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
httpSrv.Shutdown(ctx)
|
||||||
|
httpsSrv.Shutdown(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Set appropriate timeouts**
|
||||||
|
- `DrainTimeout` < `ShutdownTimeout`
|
||||||
|
- `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds`
|
||||||
|
|
||||||
|
2. **Register cleanup callbacks** for:
|
||||||
|
- Database connections
|
||||||
|
- Message queues
|
||||||
|
- Metrics flushing
|
||||||
|
- Cache shutdown
|
||||||
|
- Background workers
|
||||||
|
|
||||||
|
3. **Health checks**
|
||||||
|
- Use `/health` for liveness (is app alive?)
|
||||||
|
- Use `/ready` for readiness (can app serve traffic?)
|
||||||
|
|
||||||
|
4. **Load balancer considerations**
|
||||||
|
- Set `preStop` hook in Kubernetes (5-10s delay)
|
||||||
|
- Allows load balancer to deregister before shutdown
|
||||||
|
|
||||||
|
5. **Monitoring**
|
||||||
|
- Track in-flight requests in metrics
|
||||||
|
- Alert on slow drains
|
||||||
|
- Monitor shutdown duration
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Shutdown Takes Too Long
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Increase drain timeout
|
||||||
|
config.DrainTimeout = 60 * time.Second
|
||||||
|
```
|
||||||
|
|
||||||
|
### Requests Still Timing Out
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Increase write timeout
|
||||||
|
config.WriteTimeout = 30 * time.Second
|
||||||
|
```
|
||||||
|
|
||||||
|
### Force Shutdown Not Working
|
||||||
|
|
||||||
|
The server will force shutdown after `ShutdownTimeout` even if requests are still in-flight. Adjust timeouts as needed.
|
||||||
|
|
||||||
|
### Debugging Shutdown
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Enable debug logging
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
|
||||||
|
logger.SetLevel("debug")
|
||||||
|
```
|
||||||
296
pkg/server/shutdown.go
Normal file
296
pkg/server/shutdown.go
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GracefulServer wraps http.Server with graceful shutdown capabilities
|
||||||
|
type GracefulServer struct {
|
||||||
|
server *http.Server
|
||||||
|
shutdownTimeout time.Duration
|
||||||
|
drainTimeout time.Duration
|
||||||
|
inFlightRequests atomic.Int64
|
||||||
|
isShuttingDown atomic.Bool
|
||||||
|
shutdownOnce sync.Once
|
||||||
|
shutdownComplete chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds configuration for the graceful server
|
||||||
|
type Config struct {
|
||||||
|
// Addr is the server address (e.g., ":8080")
|
||||||
|
Addr string
|
||||||
|
|
||||||
|
// Handler is the HTTP handler
|
||||||
|
Handler http.Handler
|
||||||
|
|
||||||
|
// ShutdownTimeout is the maximum time to wait for graceful shutdown
|
||||||
|
// Default: 30 seconds
|
||||||
|
ShutdownTimeout time.Duration
|
||||||
|
|
||||||
|
// DrainTimeout is the time to wait for in-flight requests to complete
|
||||||
|
// before forcing shutdown. Default: 25 seconds
|
||||||
|
DrainTimeout time.Duration
|
||||||
|
|
||||||
|
// ReadTimeout is the maximum duration for reading the entire request
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
|
||||||
|
// WriteTimeout is the maximum duration before timing out writes of the response
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
|
||||||
|
// IdleTimeout is the maximum amount of time to wait for the next request
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGracefulServer creates a new graceful server
|
||||||
|
func NewGracefulServer(config Config) *GracefulServer {
|
||||||
|
if config.ShutdownTimeout == 0 {
|
||||||
|
config.ShutdownTimeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
if config.DrainTimeout == 0 {
|
||||||
|
config.DrainTimeout = 25 * time.Second
|
||||||
|
}
|
||||||
|
if config.ReadTimeout == 0 {
|
||||||
|
config.ReadTimeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
if config.WriteTimeout == 0 {
|
||||||
|
config.WriteTimeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
if config.IdleTimeout == 0 {
|
||||||
|
config.IdleTimeout = 120 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
gs := &GracefulServer{
|
||||||
|
server: &http.Server{
|
||||||
|
Addr: config.Addr,
|
||||||
|
Handler: config.Handler,
|
||||||
|
ReadTimeout: config.ReadTimeout,
|
||||||
|
WriteTimeout: config.WriteTimeout,
|
||||||
|
IdleTimeout: config.IdleTimeout,
|
||||||
|
},
|
||||||
|
shutdownTimeout: config.ShutdownTimeout,
|
||||||
|
drainTimeout: config.DrainTimeout,
|
||||||
|
shutdownComplete: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return gs
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
|
||||||
|
func (gs *GracefulServer) TrackRequestsMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Check if shutting down
|
||||||
|
if gs.isShuttingDown.Load() {
|
||||||
|
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment in-flight counter
|
||||||
|
gs.inFlightRequests.Add(1)
|
||||||
|
defer gs.inFlightRequests.Add(-1)
|
||||||
|
|
||||||
|
// Serve the request
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServe starts the server and handles graceful shutdown
|
||||||
|
func (gs *GracefulServer) ListenAndServe() error {
|
||||||
|
// Wrap handler with request tracking
|
||||||
|
gs.server.Handler = gs.TrackRequestsMiddleware(gs.server.Handler)
|
||||||
|
|
||||||
|
// Start server in goroutine
|
||||||
|
serverErr := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
logger.Info("Starting server on %s", gs.server.Addr)
|
||||||
|
if err := gs.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
serverErr <- err
|
||||||
|
}
|
||||||
|
close(serverErr)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for interrupt signal
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-serverErr:
|
||||||
|
return err
|
||||||
|
case sig := <-sigChan:
|
||||||
|
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
|
||||||
|
return gs.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown performs graceful shutdown with request draining
|
||||||
|
func (gs *GracefulServer) Shutdown(ctx context.Context) error {
|
||||||
|
var shutdownErr error
|
||||||
|
|
||||||
|
gs.shutdownOnce.Do(func() {
|
||||||
|
logger.Info("Starting graceful shutdown...")
|
||||||
|
|
||||||
|
// Mark as shutting down (new requests will be rejected)
|
||||||
|
gs.isShuttingDown.Store(true)
|
||||||
|
|
||||||
|
// Create context with timeout
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Wait for in-flight requests to complete (with drain timeout)
|
||||||
|
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
|
||||||
|
defer drainCancel()
|
||||||
|
|
||||||
|
shutdownErr = gs.drainRequests(drainCtx)
|
||||||
|
if shutdownErr != nil {
|
||||||
|
logger.Error("Error draining requests: %v", shutdownErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown the server
|
||||||
|
logger.Info("Shutting down HTTP server...")
|
||||||
|
if err := gs.server.Shutdown(shutdownCtx); err != nil {
|
||||||
|
logger.Error("Error shutting down server: %v", err)
|
||||||
|
if shutdownErr == nil {
|
||||||
|
shutdownErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Graceful shutdown complete")
|
||||||
|
close(gs.shutdownComplete)
|
||||||
|
})
|
||||||
|
|
||||||
|
return shutdownErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainRequests waits for in-flight requests to complete
|
||||||
|
func (gs *GracefulServer) drainRequests(ctx context.Context) error {
|
||||||
|
ticker := time.NewTicker(100 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
for {
|
||||||
|
inFlight := gs.inFlightRequests.Load()
|
||||||
|
|
||||||
|
if inFlight == 0 {
|
||||||
|
logger.Info("All requests drained in %v", time.Since(startTime))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
|
||||||
|
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
|
||||||
|
case <-ticker.C:
|
||||||
|
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InFlightRequests returns the current number of in-flight requests
|
||||||
|
func (gs *GracefulServer) InFlightRequests() int64 {
|
||||||
|
return gs.inFlightRequests.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsShuttingDown returns true if the server is shutting down
|
||||||
|
func (gs *GracefulServer) IsShuttingDown() bool {
|
||||||
|
return gs.isShuttingDown.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks until shutdown is complete
|
||||||
|
func (gs *GracefulServer) Wait() {
|
||||||
|
<-gs.shutdownComplete
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheckHandler returns a handler that responds to health checks
|
||||||
|
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down
|
||||||
|
func (gs *GracefulServer) HealthCheckHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if gs.IsShuttingDown() {
|
||||||
|
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, err := w.Write([]byte(`{"status":"healthy"}`))
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to write. %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadinessHandler returns a handler for readiness checks
|
||||||
|
// Includes in-flight request count
|
||||||
|
func (gs *GracefulServer) ReadinessHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if gs.IsShuttingDown() {
|
||||||
|
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inFlight := gs.InFlightRequests()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShutdownCallback is a function called during shutdown
|
||||||
|
type ShutdownCallback func(context.Context) error
|
||||||
|
|
||||||
|
// shutdownCallbacks stores registered shutdown callbacks
|
||||||
|
var (
|
||||||
|
shutdownCallbacks []ShutdownCallback
|
||||||
|
shutdownCallbacksMu sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterShutdownCallback registers a callback to be called during shutdown
|
||||||
|
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
|
||||||
|
func RegisterShutdownCallback(cb ShutdownCallback) {
|
||||||
|
shutdownCallbacksMu.Lock()
|
||||||
|
defer shutdownCallbacksMu.Unlock()
|
||||||
|
shutdownCallbacks = append(shutdownCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeShutdownCallbacks runs all registered shutdown callbacks
|
||||||
|
func executeShutdownCallbacks(ctx context.Context) error {
|
||||||
|
shutdownCallbacksMu.Lock()
|
||||||
|
callbacks := make([]ShutdownCallback, len(shutdownCallbacks))
|
||||||
|
copy(callbacks, shutdownCallbacks)
|
||||||
|
shutdownCallbacksMu.Unlock()
|
||||||
|
|
||||||
|
var errors []error
|
||||||
|
for i, cb := range callbacks {
|
||||||
|
logger.Debug("Executing shutdown callback %d/%d", i+1, len(callbacks))
|
||||||
|
if err := cb(ctx); err != nil {
|
||||||
|
logger.Error("Shutdown callback %d failed: %v", i+1, err)
|
||||||
|
errors = append(errors, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return fmt.Errorf("shutdown callbacks failed: %v", errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShutdownWithCallbacks performs shutdown and executes all registered callbacks
|
||||||
|
func (gs *GracefulServer) ShutdownWithCallbacks(ctx context.Context) error {
|
||||||
|
// Execute callbacks first
|
||||||
|
if err := executeShutdownCallbacks(ctx); err != nil {
|
||||||
|
logger.Error("Error executing shutdown callbacks: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then shutdown the server
|
||||||
|
return gs.Shutdown(ctx)
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user