package observability import ( "bytes" "context" "encoding/json" "io" "log/slog" "net/http" "runtime/debug" "strings" "time" "github.com/google/uuid" "git.warky.dev/wdevs/amcs/internal/requestip" ) type contextKey string const requestIDContextKey contextKey = "request_id" const mcpToolContextKey contextKey = "mcp_tool" func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler { for i := len(middlewares) - 1; i >= 0; i-- { h = middlewares[i](h) } return h } func RequestID() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestID := r.Header.Get("X-Request-Id") if requestID == "" { requestID = uuid.NewString() } w.Header().Set("X-Request-Id", requestID) ctx := context.WithValue(r.Context(), requestIDContextKey, requestID) next.ServeHTTP(w, r.WithContext(ctx)) }) } } func Recover(log *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if recovered := recover(); recovered != nil { log.Error("panic recovered", slog.Any("panic", recovered), slog.String("request_id", RequestIDFromContext(r.Context())), slog.String("stack", string(debug.Stack())), ) http.Error(w, "internal server error", http.StatusInternalServerError) } }() next.ServeHTTP(w, r) }) } } func AccessLog(log *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if tool := mcpToolFromRequest(r); tool != "" { r = r.WithContext(context.WithValue(r.Context(), mcpToolContextKey, tool)) } recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK} started := time.Now() next.ServeHTTP(recorder, r) attrs := []any{ slog.String("request_id", RequestIDFromContext(r.Context())), slog.String("method", r.Method), slog.String("path", r.URL.Path), slog.Int("status", recorder.status), slog.Duration("duration", time.Since(started)), slog.String("remote_addr", requestip.FromRequest(r)), slog.String("mcp_session_id", mcpSessionIDFromRequest(r)), } if tool, _ := r.Context().Value(mcpToolContextKey).(string); strings.TrimSpace(tool) != "" { attrs = append(attrs, slog.String("tool", tool), slog.String("tool_call", tool)) } log.Info("http request", attrs...) }) } } func Timeout(timeout time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { if timeout <= 0 { return next } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() next.ServeHTTP(w, r.WithContext(ctx)) }) } } func RequestIDFromContext(ctx context.Context) string { value, _ := ctx.Value(requestIDContextKey).(string) return value } type statusRecorder struct { http.ResponseWriter status int } func (s *statusRecorder) WriteHeader(statusCode int) { s.status = statusCode s.ResponseWriter.WriteHeader(statusCode) } func mcpToolFromRequest(r *http.Request) string { if r == nil || r.Method != http.MethodPost || !strings.HasPrefix(r.URL.Path, "/mcp") || r.Body == nil { return "" } raw, err := io.ReadAll(r.Body) if err != nil { return "" } r.Body = io.NopCloser(bytes.NewReader(raw)) if len(raw) == 0 { return "" } // Support both single and batch JSON-RPC payloads. if strings.HasPrefix(strings.TrimSpace(string(raw)), "[") { var batch []rpcEnvelope if err := json.Unmarshal(raw, &batch); err != nil { return "" } for _, msg := range batch { if tool := msg.toolName(); tool != "" { return tool } } return "" } var msg rpcEnvelope if err := json.Unmarshal(raw, &msg); err != nil { return "" } return msg.toolName() } func mcpSessionIDFromRequest(r *http.Request) string { if r == nil { return "" } if v := strings.TrimSpace(r.Header.Get("MCP-Session-Id")); v != "" { return v } // Some clients/proxies may propagate the session in query params. for _, key := range []string{"session_id", "sessionId", "mcp_session_id"} { if v := strings.TrimSpace(r.URL.Query().Get(key)); v != "" { return v } } return "" } type rpcEnvelope struct { Method string `json:"method"` Params struct { Name string `json:"name"` } `json:"params"` } func (m rpcEnvelope) toolName() string { if m.Method != "tools/call" { return "" } return strings.TrimSpace(m.Params.Name) }