feat(auth): track unique tools in access metrics
Some checks failed
CI / build-and-test (push) Failing after -31m49s

* Add tool tracking to AccessTracker and metrics
* Update tests to validate tool tracking functionality
* Modify middleware to record tool usage
* Enhance observability with tool context
* Update UI to display unique tools in metrics
This commit is contained in:
2026-04-26 23:25:51 +02:00
parent 63f8dcacb6
commit b17241b928
10 changed files with 112 additions and 15 deletions

View File

@@ -33,7 +33,7 @@ func TestStatusSnapshotHidesOAuthLinkWhenDisabled(t *testing.T) {
func TestStatusSnapshotShowsTrackedAccess(t *testing.T) {
tracker := auth.NewAccessTracker()
now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC)
tracker.Record("client-a", "/files", "127.0.0.1:1234", "tester", now)
tracker.Record("client-a", "/files", "127.0.0.1:1234", "tester", "list_projects", now)
snapshot := statusSnapshot(buildinfo.Info{Version: "v1.2.3"}, tracker, true, now)
@@ -58,6 +58,9 @@ func TestStatusSnapshotShowsTrackedAccess(t *testing.T) {
if snapshot.Metrics.UniqueAgents != 1 {
t.Fatalf("Metrics.UniqueAgents = %d, want 1", snapshot.Metrics.UniqueAgents)
}
if snapshot.Metrics.UniqueTools != 1 {
t.Fatalf("Metrics.UniqueTools = %d, want 1", snapshot.Metrics.UniqueTools)
}
}
func TestStatusAPIHandlerReturnsJSON(t *testing.T) {

View File

@@ -22,6 +22,7 @@ type AccessTracker struct {
entries map[string]AccessSnapshot
ipCounts map[string]int
agentCounts map[string]int
toolCounts map[string]int
totalRequests int
}
@@ -30,10 +31,11 @@ func NewAccessTracker() *AccessTracker {
entries: make(map[string]AccessSnapshot),
ipCounts: make(map[string]int),
agentCounts: make(map[string]int),
toolCounts: make(map[string]int),
}
}
func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent string, now time.Time) {
func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent, toolName string, now time.Time) {
if t == nil || keyID == "" {
return
}
@@ -59,6 +61,9 @@ func (t *AccessTracker) Record(keyID, path, remoteAddr, userAgent string, now ti
if userAgent != "" {
t.agentCounts[userAgent]++
}
if tool := strings.TrimSpace(toolName); tool != "" {
t.toolCounts[tool]++
}
}
func normalizeRemoteAddr(value string) string {
@@ -121,8 +126,10 @@ type AccessMetrics struct {
UniquePrincipals int `json:"unique_principals"`
UniqueIPs int `json:"unique_ips"`
UniqueAgents int `json:"unique_agents"`
UniqueTools int `json:"unique_tools"`
TopIPs []RequestAggregate `json:"top_ips"`
TopAgents []RequestAggregate `json:"top_agents"`
TopTools []RequestAggregate `json:"top_tools"`
}
func (t *AccessTracker) Metrics(topN int) AccessMetrics {
@@ -141,8 +148,10 @@ func (t *AccessTracker) Metrics(topN int) AccessMetrics {
UniquePrincipals: len(t.entries),
UniqueIPs: len(t.ipCounts),
UniqueAgents: len(t.agentCounts),
UniqueTools: len(t.toolCounts),
TopIPs: topAggregates(t.ipCounts, topN),
TopAgents: topAggregates(t.agentCounts, topN),
TopTools: topAggregates(t.toolCounts, topN),
}
}

View File

@@ -10,9 +10,9 @@ func TestAccessTrackerRecordAndSnapshot(t *testing.T) {
older := time.Date(2026, 4, 4, 10, 0, 0, 0, time.UTC)
newer := older.Add(2 * time.Minute)
tracker.Record("client-a", "/files", "10.0.0.1:1234", "agent-a", older)
tracker.Record("client-b", "/mcp", "10.0.0.2:1234", "agent-b", newer)
tracker.Record("client-a", "/files/1", "10.0.0.1:1234", "agent-a2", newer.Add(30*time.Second))
tracker.Record("client-a", "/files", "10.0.0.1:1234", "agent-a", "", older)
tracker.Record("client-b", "/mcp", "10.0.0.2:1234", "agent-b", "list_projects", newer)
tracker.Record("client-a", "/files/1", "10.0.0.1:1234", "agent-a2", "", newer.Add(30*time.Second))
snap := tracker.Snapshot()
if len(snap) != 2 {
@@ -39,8 +39,8 @@ func TestAccessTrackerConnectedCount(t *testing.T) {
tracker := NewAccessTracker()
now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC)
tracker.Record("recent", "/mcp", "", "", now.Add(-2*time.Minute))
tracker.Record("stale", "/mcp", "", "", now.Add(-11*time.Minute))
tracker.Record("recent", "/mcp", "", "", "", now.Add(-2*time.Minute))
tracker.Record("stale", "/mcp", "", "", "", now.Add(-11*time.Minute))
if got := tracker.ConnectedCount(now, 10*time.Minute); got != 1 {
t.Fatalf("ConnectedCount() = %d, want 1", got)
@@ -51,10 +51,10 @@ func TestAccessTrackerMetrics(t *testing.T) {
tracker := NewAccessTracker()
now := time.Date(2026, 4, 4, 12, 0, 0, 0, time.UTC)
tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", now)
tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", now.Add(1*time.Second))
tracker.Record("client-b", "/files", "10.0.0.2:5678", "agent-b", now.Add(2*time.Second))
tracker.Record("client-c", "/files", "10.0.0.2:5678", "agent-b", now.Add(3*time.Second))
tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", "list_projects", now)
tracker.Record("client-a", "/mcp", "10.0.0.1:1234", "agent-a", "list_projects", now.Add(1*time.Second))
tracker.Record("client-b", "/files", "10.0.0.2:5678", "agent-b", "", now.Add(2*time.Second))
tracker.Record("client-c", "/files", "10.0.0.2:5678", "agent-b", "search_thoughts", now.Add(3*time.Second))
metrics := tracker.Metrics(5)
if metrics.TotalRequests != 4 {
@@ -69,6 +69,9 @@ func TestAccessTrackerMetrics(t *testing.T) {
if metrics.UniqueAgents != 2 {
t.Fatalf("UniqueAgents = %d, want 2", metrics.UniqueAgents)
}
if metrics.UniqueTools != 2 {
t.Fatalf("UniqueTools = %d, want 2", metrics.UniqueTools)
}
if len(metrics.TopIPs) != 2 {
t.Fatalf("len(TopIPs) = %d, want 2", len(metrics.TopIPs))
}
@@ -84,4 +87,10 @@ func TestAccessTrackerMetrics(t *testing.T) {
if metrics.TopAgents[0].RequestCount != 2 || metrics.TopAgents[1].RequestCount != 2 {
t.Fatalf("TopAgents counts = %+v, want both counts to be 2", metrics.TopAgents)
}
if len(metrics.TopTools) != 2 {
t.Fatalf("len(TopTools) = %d, want 2", len(metrics.TopTools))
}
if metrics.TopTools[0].Key != "list_projects" || metrics.TopTools[0].RequestCount != 2 {
t.Fatalf("TopTools[0] = %+v, want list_projects with count 2", metrics.TopTools[0])
}
}

View File

@@ -1,6 +1,8 @@
package auth
import (
"bytes"
"encoding/json"
"io"
"log/slog"
"net/http"
@@ -8,6 +10,7 @@ import (
"testing"
"git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/observability"
)
func testLogger() *slog.Logger {
@@ -188,3 +191,50 @@ func TestMiddlewareRecordsForwardedRemoteAddr(t *testing.T) {
t.Fatalf("snapshot remote_addr = %q, want %q", snap[0].RemoteAddr, "203.0.113.99")
}
}
func TestMiddlewareRecordsMCPToolUsage(t *testing.T) {
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
if err != nil {
t.Fatalf("NewKeyring() error = %v", err)
}
tracker := NewAccessTracker()
logger := testLogger()
authenticated := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
handler := observability.AccessLog(logger)(authenticated)
payload := map[string]any{
"jsonrpc": "2.0",
"id": "1",
"method": "tools/call",
"params": map[string]any{
"name": "list_projects",
},
}
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body))
req.Header.Set("x-brain-key", "secret")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
}
metrics := tracker.Metrics(10)
if metrics.UniqueTools != 1 {
t.Fatalf("UniqueTools = %d, want 1", metrics.UniqueTools)
}
if len(metrics.TopTools) != 1 {
t.Fatalf("len(TopTools) = %d, want 1", len(metrics.TopTools))
}
if metrics.TopTools[0].Key != "list_projects" || metrics.TopTools[0].RequestCount != 1 {
t.Fatalf("TopTools[0] = %+v, want list_projects with count 1", metrics.TopTools[0])
}
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/observability"
"git.warky.dev/wdevs/amcs/internal/requestip"
)
@@ -23,7 +24,14 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
}
recordAccess := func(r *http.Request, keyID string) {
if tracker != nil {
tracker.Record(keyID, r.URL.Path, requestip.FromRequest(r), r.UserAgent(), time.Now())
tracker.Record(
keyID,
r.URL.Path,
requestip.FromRequest(r),
r.UserAgent(),
observability.MCPToolFromContext(r.Context()),
time.Now(),
)
}
}
return func(next http.Handler) http.Handler {

View File

@@ -106,6 +106,11 @@ func RequestIDFromContext(ctx context.Context) string {
return value
}
func MCPToolFromContext(ctx context.Context) string {
value, _ := ctx.Value(mcpToolContextKey).(string)
return strings.TrimSpace(value)
}
type statusRecorder struct {
http.ResponseWriter
status int