From b17241b928ebcb132923d70791ec781e3f72240e Mon Sep 17 00:00:00 2001 From: Hein Date: Sun, 26 Apr 2026 23:25:51 +0200 Subject: [PATCH] feat(auth): track unique tools in access metrics * Add tool tracking to AccessTracker and metrics * Update tests to validate tool tracking functionality * Modify middleware to record tool usage * Enhance observability with tool context * Update UI to display unique tools in metrics --- internal/app/status_test.go | 5 +- internal/auth/access_tracker.go | 11 +++- internal/auth/access_tracker_test.go | 27 ++++++---- internal/auth/keyring_test.go | 50 +++++++++++++++++++ internal/auth/middleware.go | 10 +++- internal/observability/http.go | 5 ++ ui/src/App.svelte | 4 +- .../components/dashboard/DashboardPage.svelte | 7 ++- .../components/dashboard/StatusCards.svelte | 6 ++- ui/src/types.ts | 2 + 10 files changed, 112 insertions(+), 15 deletions(-) diff --git a/internal/app/status_test.go b/internal/app/status_test.go index 20808ae..0736663 100644 --- a/internal/app/status_test.go +++ b/internal/app/status_test.go @@ -33,7 +33,7 @@ func TestStatusSnapshotHidesOAuthLinkWhenDisabled(t *testing.T) { func TestStatusSnapshotShowsTrackedAccess(t *testing.T) { tracker := auth.NewAccessTracker() now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC) - tracker.Record("client-a", "/files", "127.0.0.1:1234", "tester", now) + tracker.Record("client-a", "/files", "127.0.0.1:1234", "tester", "list_projects", now) snapshot := statusSnapshot(buildinfo.Info{Version: "v1.2.3"}, tracker, true, now) @@ -58,6 +58,9 @@ func TestStatusSnapshotShowsTrackedAccess(t *testing.T) { if snapshot.Metrics.UniqueAgents != 1 { t.Fatalf("Metrics.UniqueAgents = %d, want 1", snapshot.Metrics.UniqueAgents) } + if snapshot.Metrics.UniqueTools != 1 { + t.Fatalf("Metrics.UniqueTools = %d, want 1", snapshot.Metrics.UniqueTools) + } } func TestStatusAPIHandlerReturnsJSON(t *testing.T) { diff --git a/internal/auth/access_tracker.go b/internal/auth/access_tracker.go index 85d6e2c..06019ce 100644 --- a/internal/auth/access_tracker.go +++ b/internal/auth/access_tracker.go @@ -22,6 +22,7 @@ type AccessTracker struct { entries map[string]AccessSnapshot ipCounts map[string]int agentCounts map[string]int + toolCounts map[string]int totalRequests int } @@ -30,10 +31,11 @@ func NewAccessTracker() *AccessTracker { entries: make(map[string]AccessSnapshot), ipCounts: make(map[string]int), agentCounts: make(map[string]int), + toolCounts: make(map[string]int), } } -func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent string, now time.Time) { +func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent, toolName string, now time.Time) { if t == nil || keyID == "" { return } @@ -59,6 +61,9 @@ func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent string, now ti if userAgent != "" { t.agentCounts[userAgent]++ } + if tool := strings.TrimSpace(toolName); tool != "" { + t.toolCounts[tool]++ + } } func normalizeRemoteAddr(value string) string { @@ -121,8 +126,10 @@ type AccessMetrics struct { UniquePrincipals int `json:"unique_principals"` UniqueIPs int `json:"unique_ips"` UniqueAgents int `json:"unique_agents"` + UniqueTools int `json:"unique_tools"` TopIPs []RequestAggregate `json:"top_ips"` TopAgents []RequestAggregate `json:"top_agents"` + TopTools []RequestAggregate `json:"top_tools"` } func (t *AccessTracker) Metrics(topN int) AccessMetrics { @@ -141,8 +148,10 @@ func (t *AccessTracker) Metrics(topN int) AccessMetrics { UniquePrincipals: len(t.entries), UniqueIPs: len(t.ipCounts), UniqueAgents: len(t.agentCounts), + UniqueTools: len(t.toolCounts), TopIPs: topAggregates(t.ipCounts, topN), TopAgents: topAggregates(t.agentCounts, topN), + TopTools: topAggregates(t.toolCounts, topN), } } diff --git a/internal/auth/access_tracker_test.go b/internal/auth/access_tracker_test.go index a174c4a..7c9a7f1 100644 --- a/internal/auth/access_tracker_test.go +++ b/internal/auth/access_tracker_test.go @@ -10,9 +10,9 @@ func TestAccessTrackerRecordAndSnapshot(t *testing.T) { older := time.Date(2026, 4, 4, 10, 0, 0, 0, time.UTC) newer := older.Add(2 * time.Minute) - tracker.Record("client-a", "/files", "10.0.0.1:1234", "agent-a", older) - tracker.Record("client-b", "/mcp", "10.0.0.2:1234", "agent-b", newer) - tracker.Record("client-a", "/files/1", "10.0.0.1:1234", "agent-a2", newer.Add(30*time.Second)) + tracker.Record("client-a", "/files", "10.0.0.1:1234", "agent-a", "", older) + tracker.Record("client-b", "/mcp", "10.0.0.2:1234", "agent-b", "list_projects", newer) + tracker.Record("client-a", "/files/1", "10.0.0.1:1234", "agent-a2", "", newer.Add(30*time.Second)) snap := tracker.Snapshot() if len(snap) != 2 { @@ -39,8 +39,8 @@ func TestAccessTrackerConnectedCount(t *testing.T) { tracker := NewAccessTracker() now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC) - tracker.Record("recent", "/mcp", "", "", now.Add(-2*time.Minute)) - tracker.Record("stale", "/mcp", "", "", now.Add(-11*time.Minute)) + tracker.Record("recent", "/mcp", "", "", "", now.Add(-2*time.Minute)) + tracker.Record("stale", "/mcp", "", "", "", now.Add(-11*time.Minute)) if got := tracker.ConnectedCount(now, 10*time.Minute); got != 1 { t.Fatalf("ConnectedCount() = %d, want 1", got) @@ -51,10 +51,10 @@ func TestAccessTrackerMetrics(t *testing.T) { tracker := NewAccessTracker() now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC) - tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", now) - tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", now.Add(1*time.Second)) - tracker.Record("client-b", "/files", "10.0.0.2:5678", "agent-b", now.Add(2*time.Second)) - tracker.Record("client-c", "/files", "10.0.0.2:5678", "agent-b", now.Add(3*time.Second)) + tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", "list_projects", now) + tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", "list_projects", now.Add(1*time.Second)) + tracker.Record("client-b", "/files", "10.0.0.2:5678", "agent-b", "", now.Add(2*time.Second)) + tracker.Record("client-c", "/files", "10.0.0.2:5678", "agent-b", "search_thoughts", now.Add(3*time.Second)) metrics := tracker.Metrics(5) if metrics.TotalRequests != 4 { @@ -69,6 +69,9 @@ func TestAccessTrackerMetrics(t *testing.T) { if metrics.UniqueAgents != 2 { t.Fatalf("UniqueAgents = %d, want 2", metrics.UniqueAgents) } + if metrics.UniqueTools != 2 { + t.Fatalf("UniqueTools = %d, want 2", metrics.UniqueTools) + } if len(metrics.TopIPs) != 2 { t.Fatalf("len(TopIPs) = %d, want 2", len(metrics.TopIPs)) } @@ -84,4 +87,10 @@ func TestAccessTrackerMetrics(t *testing.T) { if metrics.TopAgents[0].RequestCount != 2 || metrics.TopAgents[1].RequestCount != 2 { t.Fatalf("TopAgents counts = %+v, want both counts to be 2", metrics.TopAgents) } + if len(metrics.TopTools) != 2 { + t.Fatalf("len(TopTools) = %d, want 2", len(metrics.TopTools)) + } + if metrics.TopTools[0].Key != "list_projects" || metrics.TopTools[0].RequestCount != 2 { + t.Fatalf("TopTools[0] = %+v, want list_projects with count 2", metrics.TopTools[0]) + } } diff --git a/internal/auth/keyring_test.go b/internal/auth/keyring_test.go index f8ee5a8..3297fee 100644 --- a/internal/auth/keyring_test.go +++ b/internal/auth/keyring_test.go @@ -1,6 +1,8 @@ package auth import ( + "bytes" + "encoding/json" "io" "log/slog" "net/http" @@ -8,6 +10,7 @@ import ( "testing" "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/observability" ) func testLogger() *slog.Logger { @@ -188,3 +191,50 @@ func TestMiddlewareRecordsForwardedRemoteAddr(t *testing.T) { t.Fatalf("snapshot remote_addr = %q, want %q", snap[0].RemoteAddr, "203.0.113.99") } } + +func TestMiddlewareRecordsMCPToolUsage(t *testing.T) { + keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) + if err != nil { + t.Fatalf("NewKeyring() error = %v", err) + } + tracker := NewAccessTracker() + logger := testLogger() + + authenticated := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + handler := observability.AccessLog(logger)(authenticated) + + payload := map[string]any{ + "jsonrpc": "2.0", + "id": "1", + "method": "tools/call", + "params": map[string]any{ + "name": "list_projects", + }, + } + body, err := json.Marshal(payload) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body)) + req.Header.Set("x-brain-key", "secret") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } + + metrics := tracker.Metrics(10) + if metrics.UniqueTools != 1 { + t.Fatalf("UniqueTools = %d, want 1", metrics.UniqueTools) + } + if len(metrics.TopTools) != 1 { + t.Fatalf("len(TopTools) = %d, want 1", len(metrics.TopTools)) + } + if metrics.TopTools[0].Key != "list_projects" || metrics.TopTools[0].RequestCount != 1 { + t.Fatalf("TopTools[0] = %+v, want list_projects with count 1", metrics.TopTools[0]) + } +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index d67ae14..8df9368 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -9,6 +9,7 @@ import ( "time" "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/observability" "git.warky.dev/wdevs/amcs/internal/requestip" ) @@ -23,7 +24,14 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg } recordAccess := func(r *http.Request, keyID string) { if tracker != nil { - tracker.Record(keyID, r.URL.Path, requestip.FromRequest(r), r.UserAgent(), time.Now()) + tracker.Record( + keyID, + r.URL.Path, + requestip.FromRequest(r), + r.UserAgent(), + observability.MCPToolFromContext(r.Context()), + time.Now(), + ) } } return func(next http.Handler) http.Handler { diff --git a/internal/observability/http.go b/internal/observability/http.go index 064392f..67f2ffc 100644 --- a/internal/observability/http.go +++ b/internal/observability/http.go @@ -106,6 +106,11 @@ func RequestIDFromContext(ctx context.Context) string { return value } +func MCPToolFromContext(ctx context.Context) string { + value, _ := ctx.Value(mcpToolContextKey).(string) + return strings.TrimSpace(value) +} + type statusRecorder struct { http.ResponseWriter status int diff --git a/ui/src/App.svelte b/ui/src/App.svelte index ee9edfa..d9d91d4 100644 --- a/ui/src/App.svelte +++ b/ui/src/App.svelte @@ -128,8 +128,10 @@ unique_principals: raw?.metrics?.unique_principals ?? 0, unique_ips: raw?.metrics?.unique_ips ?? 0, unique_agents: raw?.metrics?.unique_agents ?? 0, + unique_tools: raw?.metrics?.unique_tools ?? 0, top_ips: Array.isArray(raw?.metrics?.top_ips) ? raw.metrics.top_ips : [], - top_agents: Array.isArray(raw?.metrics?.top_agents) ? raw.metrics.top_agents : [] + top_agents: Array.isArray(raw?.metrics?.top_agents) ? raw.metrics.top_agents : [], + top_tools: Array.isArray(raw?.metrics?.top_tools) ? raw.metrics.top_tools : [] } }; } catch (err) { diff --git a/ui/src/components/dashboard/DashboardPage.svelte b/ui/src/components/dashboard/DashboardPage.svelte index 1f1bffa..a756e32 100644 --- a/ui/src/components/dashboard/DashboardPage.svelte +++ b/ui/src/components/dashboard/DashboardPage.svelte @@ -50,7 +50,7 @@ {/if} {#if data} -
+
+
{/if} diff --git a/ui/src/components/dashboard/StatusCards.svelte b/ui/src/components/dashboard/StatusCards.svelte index 77af0a8..32cb06a 100644 --- a/ui/src/components/dashboard/StatusCards.svelte +++ b/ui/src/components/dashboard/StatusCards.svelte @@ -4,7 +4,7 @@ const { data }: { data: StatusResponse } = $props(); -
+

Connected users

{data.connected_count}

@@ -25,6 +25,10 @@

Unique agents

{data.metrics.unique_agents}

+
+

Unique MCP tools

+

{data.metrics.unique_tools}

+

Version

{data.version}

diff --git a/ui/src/types.ts b/ui/src/types.ts index 13704f3..a1221b7 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -17,8 +17,10 @@ export type AccessMetrics = { unique_principals: number; unique_ips: number; unique_agents: number; + unique_tools: number; top_ips: RequestAggregate[]; top_agents: RequestAggregate[]; + top_tools: RequestAggregate[]; }; export type StatusResponse = {