package auth import ( "encoding/base64" "net/http" "net/http/httptest" "testing" "git.warky.dev/wdevs/amcs/internal/config" ) func TestNewOAuthRegistryAndLookup(t *testing.T) { _, err := NewOAuthRegistry(nil) if err == nil { t.Fatal("NewOAuthRegistry(nil) error = nil, want error") } registry, err := NewOAuthRegistry([]config.OAuthClient{{ ID: "oauth-client", ClientID: "client-id", ClientSecret: "client-secret", }}) if err != nil { t.Fatalf("NewOAuthRegistry() error = %v", err) } if got, ok := registry.Lookup("client-id", "client-secret"); !ok || got != "oauth-client" { t.Fatalf("Lookup(client-id, client-secret) = (%q, %v), want (oauth-client, true)", got, ok) } if _, ok := registry.Lookup("client-id", "wrong"); ok { t.Fatal("Lookup(client-id, wrong) = true, want false") } } func TestMiddlewareAllowsOAuthBasicAuthAndSetsContext(t *testing.T) { oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{ ID: "oauth-client", ClientID: "client-id", ClientSecret: "client-secret", }}) if err != nil { t.Fatalf("NewOAuthRegistry() error = %v", err) } handler := Middleware(config.AuthConfig{}, nil, oauthRegistry, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { keyID, ok := KeyIDFromContext(r.Context()) if !ok || keyID != "oauth-client" { t.Fatalf("KeyIDFromContext() = (%q, %v), want (oauth-client, true)", keyID, ok) } w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/mcp", nil) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("client-id:client-secret"))) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } } func TestMiddlewareRejectsOAuthMissingOrInvalidCredentials(t *testing.T) { oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{ ID: "oauth-client", ClientID: "client-id", ClientSecret: "client-secret", }}) if err != nil { t.Fatalf("NewOAuthRegistry() error = %v", err) } handler := Middleware(config.AuthConfig{}, nil, oauthRegistry, nil, nil, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("next handler should not be called") })) req := httptest.NewRequest(http.MethodGet, "/mcp", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Fatalf("missing credentials status = %d, want %d", rec.Code, http.StatusUnauthorized) } req = httptest.NewRequest(http.MethodGet, "/mcp", nil) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("client-id:wrong"))) rec = httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Fatalf("invalid credentials status = %d, want %d", rec.Code, http.StatusUnauthorized) } }