Some checks failed
CI / build-and-test (push) Failing after -31m49s
* Add tool tracking to AccessTracker and metrics * Update tests to validate tool tracking functionality * Modify middleware to record tool usage * Enhance observability with tool context * Update UI to display unique tools in metrics
241 lines
7.5 KiB
Go
241 lines
7.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/config"
|
|
"git.warky.dev/wdevs/amcs/internal/observability"
|
|
)
|
|
|
|
// testLogger returns a *slog.Logger backed by a text handler that writes to
// io.Discard, so code under test can log without producing test output.
func testLogger() *slog.Logger {
	discard := slog.NewTextHandler(io.Discard, nil)
	return slog.New(discard)
}
|
|
|
|
func TestNewKeyringAndLookup(t *testing.T) {
|
|
_, err := NewKeyring(nil)
|
|
if err == nil {
|
|
t.Fatal("NewKeyring(nil) error = nil, want error")
|
|
}
|
|
|
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
|
if err != nil {
|
|
t.Fatalf("NewKeyring() error = %v", err)
|
|
}
|
|
|
|
if got, ok := keyring.Lookup("secret"); !ok || got != "client-a" {
|
|
t.Fatalf("Lookup(secret) = (%q, %v), want (client-a, true)", got, ok)
|
|
}
|
|
if _, ok := keyring.Lookup("wrong"); ok {
|
|
t.Fatal("Lookup(wrong) = true, want false")
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) {
|
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
|
if err != nil {
|
|
t.Fatalf("NewKeyring() error = %v", err)
|
|
}
|
|
|
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
keyID, ok := KeyIDFromContext(r.Context())
|
|
if !ok || keyID != "client-a" {
|
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
|
req.Header.Set("x-brain-key", "secret")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareAllowsBearerAuthAndSetsContext(t *testing.T) {
|
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
|
if err != nil {
|
|
t.Fatalf("NewKeyring() error = %v", err)
|
|
}
|
|
|
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
keyID, ok := KeyIDFromContext(r.Context())
|
|
if !ok || keyID != "client-a" {
|
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
|
req.Header.Set("Authorization", "Bearer secret")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestMiddlewarePrefersExplicitHeaderOverBearerAuth(t *testing.T) {
|
|
keyring, err := NewKeyring([]config.APIKey{
|
|
{ID: "client-a", Value: "secret"},
|
|
{ID: "client-b", Value: "other-secret"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewKeyring() error = %v", err)
|
|
}
|
|
|
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
keyID, ok := KeyIDFromContext(r.Context())
|
|
if !ok || keyID != "client-a" {
|
|
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
|
req.Header.Set("x-brain-key", "secret")
|
|
req.Header.Set("Authorization", "Bearer other-secret")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) {
|
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
|
if err != nil {
|
|
t.Fatalf("NewKeyring() error = %v", err)
|
|
}
|
|
|
|
handler := Middleware(config.AuthConfig{
|
|
HeaderName: "x-brain-key",
|
|
QueryParam: "key",
|
|
AllowQueryParam: true,
|
|
}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/mcp?key=secret", nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusNoContent {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) {
|
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
|
if err != nil {
|
|
t.Fatalf("NewKeyring() error = %v", err)
|
|
}
|
|
|
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatal("next handler should not be called")
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("missing key status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
|
}
|
|
|
|
req = httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
|
req.Header.Set("x-brain-key", "wrong")
|
|
rec = httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("invalid key status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareRecordsForwardedRemoteAddr(t *testing.T) {
|
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
|
if err != nil {
|
|
t.Fatalf("NewKeyring() error = %v", err)
|
|
}
|
|
tracker := NewAccessTracker()
|
|
|
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
|
req.RemoteAddr = "10.0.0.5:2222"
|
|
req.Header.Set("x-brain-key", "secret")
|
|
req.Header.Set("X-Real-IP", "203.0.113.99")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusNoContent {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
|
}
|
|
|
|
snap := tracker.Snapshot()
|
|
if len(snap) != 1 {
|
|
t.Fatalf("len(snapshot) = %d, want 1", len(snap))
|
|
}
|
|
if snap[0].RemoteAddr != "203.0.113.99" {
|
|
t.Fatalf("snapshot remote_addr = %q, want %q", snap[0].RemoteAddr, "203.0.113.99")
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareRecordsMCPToolUsage(t *testing.T) {
|
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
|
if err != nil {
|
|
t.Fatalf("NewKeyring() error = %v", err)
|
|
}
|
|
tracker := NewAccessTracker()
|
|
logger := testLogger()
|
|
|
|
authenticated := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
handler := observability.AccessLog(logger)(authenticated)
|
|
|
|
payload := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": "1",
|
|
"method": "tools/call",
|
|
"params": map[string]any{
|
|
"name": "list_projects",
|
|
},
|
|
}
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
t.Fatalf("json.Marshal() error = %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body))
|
|
req.Header.Set("x-brain-key", "secret")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusNoContent {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
|
}
|
|
|
|
metrics := tracker.Metrics(10)
|
|
if metrics.UniqueTools != 1 {
|
|
t.Fatalf("UniqueTools = %d, want 1", metrics.UniqueTools)
|
|
}
|
|
if len(metrics.TopTools) != 1 {
|
|
t.Fatalf("len(TopTools) = %d, want 1", len(metrics.TopTools))
|
|
}
|
|
if metrics.TopTools[0].Key != "list_projects" || metrics.TopTools[0].RequestCount != 1 {
|
|
t.Fatalf("TopTools[0] = %+v, want list_projects with count 1", metrics.TopTools[0])
|
|
}
|
|
}
|