feat(auth): track unique tools in access metrics
Some checks failed
CI / build-and-test (push) Failing after -31m49s

* Add tool tracking to AccessTracker and metrics
* Update tests to validate tool tracking functionality
* Modify middleware to record tool usage
* Enhance observability with tool context
* Update UI to display unique tools in metrics
This commit is contained in:
2026-04-26 23:25:51 +02:00
parent 63f8dcacb6
commit b17241b928
10 changed files with 112 additions and 15 deletions

View File

@@ -33,7 +33,7 @@ func TestStatusSnapshotHidesOAuthLinkWhenDisabled(t *testing.T) {
func TestStatusSnapshotShowsTrackedAccess(t *testing.T) { func TestStatusSnapshotShowsTrackedAccess(t *testing.T) {
tracker := auth.NewAccessTracker() tracker := auth.NewAccessTracker()
now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC) now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC)
tracker.Record("client-a", "/files", "127.0.0.1:1234", "tester", now) tracker.Record("client-a", "/files", "127.0.0.1:1234", "tester", "list_projects", now)
snapshot := statusSnapshot(buildinfo.Info{Version: "v1.2.3"}, tracker, true, now) snapshot := statusSnapshot(buildinfo.Info{Version: "v1.2.3"}, tracker, true, now)
@@ -58,6 +58,9 @@ func TestStatusSnapshotShowsTrackedAccess(t *testing.T) {
if snapshot.Metrics.UniqueAgents != 1 { if snapshot.Metrics.UniqueAgents != 1 {
t.Fatalf("Metrics.UniqueAgents = %d, want 1", snapshot.Metrics.UniqueAgents) t.Fatalf("Metrics.UniqueAgents = %d, want 1", snapshot.Metrics.UniqueAgents)
} }
if snapshot.Metrics.UniqueTools != 1 {
t.Fatalf("Metrics.UniqueTools = %d, want 1", snapshot.Metrics.UniqueTools)
}
} }
func TestStatusAPIHandlerReturnsJSON(t *testing.T) { func TestStatusAPIHandlerReturnsJSON(t *testing.T) {

View File

@@ -22,6 +22,7 @@ type AccessTracker struct {
entries map[string]AccessSnapshot entries map[string]AccessSnapshot
ipCounts map[string]int ipCounts map[string]int
agentCounts map[string]int agentCounts map[string]int
toolCounts map[string]int
totalRequests int totalRequests int
} }
@@ -30,10 +31,11 @@ func NewAccessTracker() *AccessTracker {
entries: make(map[string]AccessSnapshot), entries: make(map[string]AccessSnapshot),
ipCounts: make(map[string]int), ipCounts: make(map[string]int),
agentCounts: make(map[string]int), agentCounts: make(map[string]int),
toolCounts: make(map[string]int),
} }
} }
func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent string, now time.Time) { func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent, toolName string, now time.Time) {
if t == nil || keyID == "" { if t == nil || keyID == "" {
return return
} }
@@ -59,6 +61,9 @@ func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent string, now ti
if userAgent != "" { if userAgent != "" {
t.agentCounts[userAgent]++ t.agentCounts[userAgent]++
} }
if tool := strings.TrimSpace(toolName); tool != "" {
t.toolCounts[tool]++
}
} }
func normalizeRemoteAddr(value string) string { func normalizeRemoteAddr(value string) string {
@@ -121,8 +126,10 @@ type AccessMetrics struct {
UniquePrincipals int `json:"unique_principals"` UniquePrincipals int `json:"unique_principals"`
UniqueIPs int `json:"unique_ips"` UniqueIPs int `json:"unique_ips"`
UniqueAgents int `json:"unique_agents"` UniqueAgents int `json:"unique_agents"`
UniqueTools int `json:"unique_tools"`
TopIPs []RequestAggregate `json:"top_ips"` TopIPs []RequestAggregate `json:"top_ips"`
TopAgents []RequestAggregate `json:"top_agents"` TopAgents []RequestAggregate `json:"top_agents"`
TopTools []RequestAggregate `json:"top_tools"`
} }
func (t *AccessTracker) Metrics(topN int) AccessMetrics { func (t *AccessTracker) Metrics(topN int) AccessMetrics {
@@ -141,8 +148,10 @@ func (t *AccessTracker) Metrics(topN int) AccessMetrics {
UniquePrincipals: len(t.entries), UniquePrincipals: len(t.entries),
UniqueIPs: len(t.ipCounts), UniqueIPs: len(t.ipCounts),
UniqueAgents: len(t.agentCounts), UniqueAgents: len(t.agentCounts),
UniqueTools: len(t.toolCounts),
TopIPs: topAggregates(t.ipCounts, topN), TopIPs: topAggregates(t.ipCounts, topN),
TopAgents: topAggregates(t.agentCounts, topN), TopAgents: topAggregates(t.agentCounts, topN),
TopTools: topAggregates(t.toolCounts, topN),
} }
} }

View File

@@ -10,9 +10,9 @@ func TestAccessTrackerRecordAndSnapshot(t *testing.T) {
older := time.Date(2026, 4, 4, 10, 0, 0, 0, time.UTC) older := time.Date(2026, 4, 4, 10, 0, 0, 0, time.UTC)
newer := older.Add(2 * time.Minute) newer := older.Add(2 * time.Minute)
tracker.Record("client-a", "/files", "10.0.0.1:1234", "agent-a", older) tracker.Record("client-a", "/files", "10.0.0.1:1234", "agent-a", "", older)
tracker.Record("client-b", "/mcp", "10.0.0.2:1234", "agent-b", newer) tracker.Record("client-b", "/mcp", "10.0.0.2:1234", "agent-b", "list_projects", newer)
tracker.Record("client-a", "/files/1", "10.0.0.1:1234", "agent-a2", newer.Add(30*time.Second)) tracker.Record("client-a", "/files/1", "10.0.0.1:1234", "agent-a2", "", newer.Add(30*time.Second))
snap := tracker.Snapshot() snap := tracker.Snapshot()
if len(snap) != 2 { if len(snap) != 2 {
@@ -39,8 +39,8 @@ func TestAccessTrackerConnectedCount(t *testing.T) {
tracker := NewAccessTracker() tracker := NewAccessTracker()
now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC) now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC)
tracker.Record("recent", "/mcp", "", "", now.Add(-2*time.Minute)) tracker.Record("recent", "/mcp", "", "", "", now.Add(-2*time.Minute))
tracker.Record("stale", "/mcp", "", "", now.Add(-11*time.Minute)) tracker.Record("stale", "/mcp", "", "", "", now.Add(-11*time.Minute))
if got := tracker.ConnectedCount(now, 10*time.Minute); got != 1 { if got := tracker.ConnectedCount(now, 10*time.Minute); got != 1 {
t.Fatalf("ConnectedCount() = %d, want 1", got) t.Fatalf("ConnectedCount() = %d, want 1", got)
@@ -51,10 +51,10 @@ func TestAccessTrackerMetrics(t *testing.T) {
tracker := NewAccessTracker() tracker := NewAccessTracker()
now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC) now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC)
tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", now) tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", "list_projects", now)
tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", now.Add(1*time.Second)) tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", "list_projects", now.Add(1*time.Second))
tracker.Record("client-b", "/files", "10.0.0.2:5678", "agent-b", now.Add(2*time.Second)) tracker.Record("client-b", "/files", "10.0.0.2:5678", "agent-b", "", now.Add(2*time.Second))
tracker.Record("client-c", "/files", "10.0.0.2:5678", "agent-b", now.Add(3*time.Second)) tracker.Record("client-c", "/files", "10.0.0.2:5678", "agent-b", "search_thoughts", now.Add(3*time.Second))
metrics := tracker.Metrics(5) metrics := tracker.Metrics(5)
if metrics.TotalRequests != 4 { if metrics.TotalRequests != 4 {
@@ -69,6 +69,9 @@ func TestAccessTrackerMetrics(t *testing.T) {
if metrics.UniqueAgents != 2 { if metrics.UniqueAgents != 2 {
t.Fatalf("UniqueAgents = %d, want 2", metrics.UniqueAgents) t.Fatalf("UniqueAgents = %d, want 2", metrics.UniqueAgents)
} }
if metrics.UniqueTools != 2 {
t.Fatalf("UniqueTools = %d, want 2", metrics.UniqueTools)
}
if len(metrics.TopIPs) != 2 { if len(metrics.TopIPs) != 2 {
t.Fatalf("len(TopIPs) = %d, want 2", len(metrics.TopIPs)) t.Fatalf("len(TopIPs) = %d, want 2", len(metrics.TopIPs))
} }
@@ -84,4 +87,10 @@ func TestAccessTrackerMetrics(t *testing.T) {
if metrics.TopAgents[0].RequestCount != 2 || metrics.TopAgents[1].RequestCount != 2 { if metrics.TopAgents[0].RequestCount != 2 || metrics.TopAgents[1].RequestCount != 2 {
t.Fatalf("TopAgents counts = %+v, want both counts to be 2", metrics.TopAgents) t.Fatalf("TopAgents counts = %+v, want both counts to be 2", metrics.TopAgents)
} }
if len(metrics.TopTools) != 2 {
t.Fatalf("len(TopTools) = %d, want 2", len(metrics.TopTools))
}
if metrics.TopTools[0].Key != "list_projects" || metrics.TopTools[0].RequestCount != 2 {
t.Fatalf("TopTools[0] = %+v, want list_projects with count 2", metrics.TopTools[0])
}
} }

View File

@@ -1,6 +1,8 @@
package auth package auth
import ( import (
"bytes"
"encoding/json"
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -8,6 +10,7 @@ import (
"testing" "testing"
"git.warky.dev/wdevs/amcs/internal/config" "git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/observability"
) )
func testLogger() *slog.Logger { func testLogger() *slog.Logger {
@@ -188,3 +191,50 @@ func TestMiddlewareRecordsForwardedRemoteAddr(t *testing.T) {
t.Fatalf("snapshot remote_addr = %q, want %q", snap[0].RemoteAddr, "203.0.113.99") t.Fatalf("snapshot remote_addr = %q, want %q", snap[0].RemoteAddr, "203.0.113.99")
} }
} }
func TestMiddlewareRecordsMCPToolUsage(t *testing.T) {
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
if err != nil {
t.Fatalf("NewKeyring() error = %v", err)
}
tracker := NewAccessTracker()
logger := testLogger()
authenticated := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
handler := observability.AccessLog(logger)(authenticated)
payload := map[string]any{
"jsonrpc": "2.0",
"id": "1",
"method": "tools/call",
"params": map[string]any{
"name": "list_projects",
},
}
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body))
req.Header.Set("x-brain-key", "secret")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
}
metrics := tracker.Metrics(10)
if metrics.UniqueTools != 1 {
t.Fatalf("UniqueTools = %d, want 1", metrics.UniqueTools)
}
if len(metrics.TopTools) != 1 {
t.Fatalf("len(TopTools) = %d, want 1", len(metrics.TopTools))
}
if metrics.TopTools[0].Key != "list_projects" || metrics.TopTools[0].RequestCount != 1 {
t.Fatalf("TopTools[0] = %+v, want list_projects with count 1", metrics.TopTools[0])
}
}

View File

@@ -9,6 +9,7 @@ import (
"time" "time"
"git.warky.dev/wdevs/amcs/internal/config" "git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/observability"
"git.warky.dev/wdevs/amcs/internal/requestip" "git.warky.dev/wdevs/amcs/internal/requestip"
) )
@@ -23,7 +24,14 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
} }
recordAccess := func(r *http.Request, keyID string) { recordAccess := func(r *http.Request, keyID string) {
if tracker != nil { if tracker != nil {
tracker.Record(keyID, r.URL.Path, requestip.FromRequest(r), r.UserAgent(), time.Now()) tracker.Record(
keyID,
r.URL.Path,
requestip.FromRequest(r),
r.UserAgent(),
observability.MCPToolFromContext(r.Context()),
time.Now(),
)
} }
} }
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {

View File

@@ -106,6 +106,11 @@ func RequestIDFromContext(ctx context.Context) string {
return value return value
} }
func MCPToolFromContext(ctx context.Context) string {
value, _ := ctx.Value(mcpToolContextKey).(string)
return strings.TrimSpace(value)
}
type statusRecorder struct { type statusRecorder struct {
http.ResponseWriter http.ResponseWriter
status int status int

View File

@@ -128,8 +128,10 @@
unique_principals: raw?.metrics?.unique_principals ?? 0, unique_principals: raw?.metrics?.unique_principals ?? 0,
unique_ips: raw?.metrics?.unique_ips ?? 0, unique_ips: raw?.metrics?.unique_ips ?? 0,
unique_agents: raw?.metrics?.unique_agents ?? 0, unique_agents: raw?.metrics?.unique_agents ?? 0,
unique_tools: raw?.metrics?.unique_tools ?? 0,
top_ips: Array.isArray(raw?.metrics?.top_ips) ? raw.metrics.top_ips : [], top_ips: Array.isArray(raw?.metrics?.top_ips) ? raw.metrics.top_ips : [],
top_agents: Array.isArray(raw?.metrics?.top_agents) ? raw.metrics.top_agents : [] top_agents: Array.isArray(raw?.metrics?.top_agents) ? raw.metrics.top_agents : [],
top_tools: Array.isArray(raw?.metrics?.top_tools) ? raw.metrics.top_tools : []
} }
}; };
} catch (err) { } catch (err) {

View File

@@ -50,7 +50,7 @@
{/if} {/if}
{#if data} {#if data}
<div class="mt-6 grid gap-6 xl:grid-cols-2"> <div class="mt-6 grid gap-6 xl:grid-cols-3">
<ConnectionBreakdown <ConnectionBreakdown
title="Requests By IP Address" title="Requests By IP Address"
entries={data.metrics.top_ips} entries={data.metrics.top_ips}
@@ -61,5 +61,10 @@
entries={data.metrics.top_agents} entries={data.metrics.top_agents}
emptyLabel="No user agents recorded yet." emptyLabel="No user agents recorded yet."
/> />
<ConnectionBreakdown
title="Requests By MCP Tool"
entries={data.metrics.top_tools}
emptyLabel="No MCP tool calls recorded yet."
/>
</div> </div>
{/if} {/if}

View File

@@ -4,7 +4,7 @@
const { data }: { data: StatusResponse } = $props(); const { data }: { data: StatusResponse } = $props();
</script> </script>
<div class="mt-6 grid gap-4 sm:grid-cols-2 xl:grid-cols-3"> <div class="mt-6 grid gap-4 sm:grid-cols-2 xl:grid-cols-4">
<div class="rounded-2xl border border-white/10 bg-white/5 p-5"> <div class="rounded-2xl border border-white/10 bg-white/5 p-5">
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">Connected users</p> <p class="text-sm uppercase tracking-[0.2em] text-slate-400">Connected users</p>
<p class="mt-2 text-3xl font-semibold text-white">{data.connected_count}</p> <p class="mt-2 text-3xl font-semibold text-white">{data.connected_count}</p>
@@ -25,6 +25,10 @@
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">Unique agents</p> <p class="text-sm uppercase tracking-[0.2em] text-slate-400">Unique agents</p>
<p class="mt-2 text-3xl font-semibold text-white">{data.metrics.unique_agents}</p> <p class="mt-2 text-3xl font-semibold text-white">{data.metrics.unique_agents}</p>
</div> </div>
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">Unique MCP tools</p>
<p class="mt-2 text-3xl font-semibold text-white">{data.metrics.unique_tools}</p>
</div>
<div class="rounded-2xl border border-white/10 bg-white/5 p-5"> <div class="rounded-2xl border border-white/10 bg-white/5 p-5">
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">Version</p> <p class="text-sm uppercase tracking-[0.2em] text-slate-400">Version</p>
<p class="mt-2 break-all text-2xl font-semibold text-white">{data.version}</p> <p class="mt-2 break-all text-2xl font-semibold text-white">{data.version}</p>

View File

@@ -17,8 +17,10 @@ export type AccessMetrics = {
unique_principals: number; unique_principals: number;
unique_ips: number; unique_ips: number;
unique_agents: number; unique_agents: number;
unique_tools: number;
top_ips: RequestAggregate[]; top_ips: RequestAggregate[];
top_agents: RequestAggregate[]; top_agents: RequestAggregate[];
top_tools: RequestAggregate[];
}; };
export type StatusResponse = { export type StatusResponse = {