From 850ad2b2ab3d0b466e37fab479505de54428a649 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 9 Apr 2026 14:04:53 +0000 Subject: [PATCH] fix(security): address all OAuth2 PR review issues Agent-Logs-Url: https://github.com/bitechdev/ResolveSpec/sessions/e886b781-c910-425f-aa6f-06d13c46dcc7 Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com> --- pkg/resolvemcp/README.md | 5 +- pkg/resolvemcp/oauth2.go | 2 +- pkg/security/README.md | 17 +++--- pkg/security/database_schema.sql | 11 +++- pkg/security/oauth_server.go | 94 ++++++++++++++++++++++++-------- pkg/security/oauth_server_db.go | 16 +++--- 6 files changed, 100 insertions(+), 45 deletions(-) diff --git a/pkg/resolvemcp/README.md b/pkg/resolvemcp/README.md index 660599f..2e3d99f 100644 --- a/pkg/resolvemcp/README.md +++ b/pkg/resolvemcp/README.md @@ -217,10 +217,11 @@ auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{ ProviderName: "google", }) -// nil = no password login; Google handles auth +// Pass `auth` so the OAuth server supports persistence, introspection, and revocation. +// Google handles the end-user authentication flow via redirect. handler.EnableOAuthServer(security.OAuthServerConfig{ Issuer: "https://api.example.com", -}, nil) +}, auth) handler.RegisterOAuth2Provider(auth, "google") ``` diff --git a/pkg/resolvemcp/oauth2.go b/pkg/resolvemcp/oauth2.go index 5ffa07e..56e7a8d 100644 --- a/pkg/resolvemcp/oauth2.go +++ b/pkg/resolvemcp/oauth2.go @@ -45,7 +45,7 @@ func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth // // auth := security.NewGoogleAuthenticator(...) // handler.RegisterOAuth2(auth, cfg) -// handler.EnableOAuthServer(resolvemcp.OAuthServerConfig{Issuer: "https://api.example.com"}) +// handler.EnableOAuthServer(security.OAuthServerConfig{Issuer: "https://api.example.com"}) // security.RegisterSecurityHooks(handler, securityList) // http.ListenAndServe(":8080", handler.HTTPHandler(securityList)) func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler { diff --git a/pkg/security/README.md b/pkg/security/README.md index 4891fef..0df1939 100644 --- a/pkg/security/README.md +++ b/pkg/security/README.md @@ -938,14 +938,14 @@ cfg := security.OAuthServerConfig{ | Field | Default | Notes | |-------|---------|-------| -| `Issuer` | — | Required | +| `Issuer` | — | Required; trailing slash is trimmed automatically | | `ProviderCallbackPath` | `/oauth/provider/callback` | | -| `LoginTitle` | `"Login"` | | +| `LoginTitle` | `"Sign in"` | | | `PersistClients` | `false` | Set `true` for multi-instance | -| `PersistCodes` | `false` | Set `true` for multi-instance | -| `DefaultScopes` | `nil` | | -| `AccessTokenTTL` | `1h` | | -| `AuthCodeTTL` | `5m` | | +| `PersistCodes` | `false` | Set `true` for multi-instance; does not require `PersistClients` | +| `DefaultScopes` | `["openid","profile","email"]` | | +| `AccessTokenTTL` | `24h` | Also used as `expires_in` in token responses | +| `AuthCodeTTL` | `2m` | | ### Operating Modes @@ -960,10 +960,11 @@ srv := security.NewOAuthServer(cfg, auth) **Mode 2 — External provider federation** -Pass `nil` as auth and register external providers. The authorize page shows a provider selection UI. +Pass a `*DatabaseAuthenticator` for persistence (authorization codes, revoke, introspect) and register external providers. The authorize endpoint redirects to the specified provider (via the `provider` query param) or to the first registered provider by default. ```go -srv := security.NewOAuthServer(cfg, nil) +auth := security.NewDatabaseAuthenticator(db) +srv := security.NewOAuthServer(cfg, auth) srv.RegisterExternalProvider(googleAuth, "google") srv.RegisterExternalProvider(githubAuth, "github") ``` diff --git a/pkg/security/database_schema.sql b/pkg/security/database_schema.sql index bbddb9b..9a652a1 100644 --- a/pkg/security/database_schema.sql +++ b/pkg/security/database_schema.sql @@ -1415,15 +1415,18 @@ CREATE TABLE IF NOT EXISTS oauth_clients ( ); -- oauth_codes: short-lived authorization codes (for multi-instance deployments) +-- Note: client_id is stored without a foreign key so codes can be persisted even +-- when OAuth clients are managed in memory rather than persisted in oauth_clients. CREATE TABLE IF NOT EXISTS oauth_codes ( id SERIAL PRIMARY KEY, code VARCHAR(255) NOT NULL UNIQUE, - client_id VARCHAR(255) NOT NULL REFERENCES oauth_clients(client_id) ON DELETE CASCADE, + client_id VARCHAR(255) NOT NULL, redirect_uri TEXT NOT NULL, client_state TEXT, code_challenge VARCHAR(255) NOT NULL, code_challenge_method VARCHAR(10) DEFAULT 'S256', session_token TEXT NOT NULL, + refresh_token TEXT, scopes TEXT[], expires_at TIMESTAMP NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP @@ -1483,7 +1486,7 @@ CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb) RETURNS TABLE(p_success bool, p_error text) LANGUAGE plpgsql AS $$ BEGIN - INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, scopes, expires_at) + INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, refresh_token, scopes, expires_at) VALUES ( p_data->>'code', p_data->>'client_id', @@ -1492,6 +1495,7 @@ BEGIN p_data->>'code_challenge', COALESCE(p_data->>'code_challenge_method', 'S256'), p_data->>'session_token', + p_data->>'refresh_token', ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')), (p_data->>'expires_at')::timestamp ); @@ -1517,6 +1521,7 @@ BEGIN 'code_challenge', code_challenge, 'code_challenge_method', code_challenge_method, 'session_token', session_token, + 'refresh_token', refresh_token, 'scopes', to_jsonb(scopes) ) INTO v_row; @@ -1540,7 +1545,7 @@ BEGIN 'username', u.username, 'email', u.email, 'user_level', u.user_level, - 'roles', to_jsonb(string_to_array(COALESCE(u.roles, ''), ',')), + 'roles', COALESCE(to_jsonb(string_to_array(NULLIF(u.roles, ''), ',')), '[]'::jsonb), 'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint, 'iat', EXTRACT(EPOCH FROM s.created_at)::bigint ) diff --git a/pkg/security/oauth_server.go b/pkg/security/oauth_server.go index 7fff22f..7ba074c 100644 --- a/pkg/security/oauth_server.go +++ b/pkg/security/oauth_server.go @@ -25,7 +25,7 @@ type OAuthServerConfig struct { ProviderCallbackPath string // LoginTitle is shown on the built-in login form when the server acts as its own - // identity provider. Defaults to "MCP Login". + // identity provider. Defaults to "Sign in". LoginTitle string // PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided. @@ -65,6 +65,7 @@ type pendingAuth struct { ProviderName string // empty = password login ExpiresAt time.Time SessionToken string // set after authentication completes + RefreshToken string // set after authentication completes when refresh tokens are issued Scopes []string // requested scopes } @@ -100,6 +101,8 @@ type OAuthServer struct { clients map[string]*oauthClient pending map[string]*pendingAuth // provider_state → pending (external flow) codes map[string]*pendingAuth // auth_code → pending (post-auth) + + done chan struct{} // closed by Close() to stop background goroutines } // NewOAuthServer creates a new MCP OAuth2 authorization server. @@ -107,6 +110,8 @@ type OAuthServer struct { // Pass a DatabaseAuthenticator to enable direct username/password login (the server // acts as its own identity provider). Pass nil to use only external providers. // External providers are added separately via RegisterExternalProvider. +// +// Call Close() to stop background goroutines when the server is no longer needed. func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer { if cfg.ProviderCallbackPath == "" { cfg.ProviderCallbackPath = "/oauth/provider/callback" @@ -123,23 +128,40 @@ func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthSe if cfg.AuthCodeTTL == 0 { cfg.AuthCodeTTL = 2 * time.Minute } + // Normalize issuer: trim trailing slash to ensure consistent endpoint URL construction. + cfg.Issuer = strings.TrimRight(cfg.Issuer, "/") s := &OAuthServer{ cfg: cfg, auth: auth, clients: make(map[string]*oauthClient), pending: make(map[string]*pendingAuth), codes: make(map[string]*pendingAuth), + done: make(chan struct{}), } go s.cleanupExpired() return s } +// Close stops the background goroutines started by NewOAuthServer. +// It is safe to call Close multiple times. +func (s *OAuthServer) Close() { + select { + case <-s.done: + // already closed + default: + close(s.done) + } +} + // RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.) // that handles user authentication via redirect. The DatabaseAuthenticator must have been // configured with WithOAuth2(providerName, ...) before calling this. // Multiple providers can be registered; the first is used as the default. +// All providers must be registered before the server starts serving requests. func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) { + s.mu.Lock() s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName}) + s.mu.Unlock() } // ProviderCallbackPath returns the configured path for external provider callbacks. @@ -169,20 +191,25 @@ func (s *OAuthServer) HTTPHandler() http.Handler { func (s *OAuthServer) cleanupExpired() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() - for range ticker.C { - now := time.Now() - s.mu.Lock() - for k, p := range s.pending { - if now.After(p.ExpiresAt) { - delete(s.pending, k) + for { + select { + case <-s.done: + return + case <-ticker.C: + now := time.Now() + s.mu.Lock() + for k, p := range s.pending { + if now.After(p.ExpiresAt) { + delete(s.pending, k) + } } - } - for k, p := range s.codes { - if now.After(p.ExpiresAt) { - delete(s.codes, k) + for k, p := range s.codes { + if now.After(p.ExpiresAt) { + delete(s.codes, k) + } } + s.mu.Unlock() } - s.mu.Unlock() } } @@ -383,7 +410,7 @@ func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) { return } - s.issueCodeAndRedirect(w, r, loginResp.Token, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes) + s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes) } // redirectToExternalProvider stores the pending auth and redirects to the configured provider. @@ -469,13 +496,13 @@ func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Req return } - s.issueCodeAndRedirect(w, r, loginResp.Token, + s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken, pending.ClientID, pending.RedirectURI, pending.ClientState, pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes) } // issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client. -func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) { +func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, refreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) { authCode, err := randomOAuthToken() if err != nil { http.Error(w, "server error", http.StatusInternalServerError) @@ -490,6 +517,7 @@ func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Reques CodeChallengeMethod: codeChallengeMethod, ProviderName: providerName, SessionToken: sessionToken, + RefreshToken: refreshToken, ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL), Scopes: scopes, } @@ -503,6 +531,7 @@ func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Reques CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, SessionToken: sessionToken, + RefreshToken: refreshToken, Scopes: scopes, ExpiresAt: pending.ExpiresAt, } @@ -565,6 +594,7 @@ func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request } var sessionToken string + var refreshToken string var scopes []string if s.cfg.PersistCodes && s.auth != nil { @@ -586,6 +616,7 @@ func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request return } sessionToken = oauthCode.SessionToken + refreshToken = oauthCode.RefreshToken scopes = oauthCode.Scopes } else { s.mu.Lock() @@ -612,10 +643,11 @@ func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request return } sessionToken = pending.SessionToken + refreshToken = pending.RefreshToken scopes = pending.Scopes } - writeOAuthToken(w, sessionToken, "", scopes) + s.writeOAuthToken(w, sessionToken, refreshToken, scopes) } func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) { @@ -634,7 +666,7 @@ func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest) return } - writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil) + s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil) return } @@ -644,7 +676,7 @@ func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest) return } - writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil) + s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil) return } @@ -672,6 +704,9 @@ func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) { if s.auth != nil { s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck + } else if len(s.providers) > 0 { + // In external-provider-only mode, attempt revocation via the first provider's auth. + s.providers[0].auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck } w.WriteHeader(http.StatusOK) } @@ -693,12 +728,22 @@ func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) token := r.FormValue("token") w.Header().Set("Content-Type", "application/json") - if token == "" || s.auth == nil { + if token == "" { w.Write([]byte(`{"active":false}`)) //nolint:errcheck return } - info, err := s.auth.OAuthIntrospectToken(r.Context(), token) + // Resolve the authenticator to use: prefer the primary auth, then the first provider's auth. + authToUse := s.auth + if authToUse == nil && len(s.providers) > 0 { + authToUse = s.providers[0].auth + } + if authToUse == nil { + w.Write([]byte(`{"active":false}`)) //nolint:errcheck + return + } + + info, err := authToUse.OAuthIntrospectToken(r.Context(), token) if err != nil { w.Write([]byte(`{"active":false}`)) //nolint:errcheck return @@ -740,7 +785,7 @@ button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;borde button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}

%s

%s -
+ @@ -815,18 +860,19 @@ func randomOAuthToken() (string, error) { func oauthSliceContains(slice []string, s string) bool { for _, v := range slice { - if strings.EqualFold(v, s) { + if v == s { return true } } return false } -func writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) { +func (s *OAuthServer) writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) { + expiresIn := int64(s.cfg.AccessTokenTTL.Seconds()) resp := map[string]interface{}{ "access_token": accessToken, "token_type": "Bearer", - "expires_in": 86400, + "expires_in": expiresIn, } if refreshToken != "" { resp["refresh_token"] = refreshToken diff --git a/pkg/security/oauth_server_db.go b/pkg/security/oauth_server_db.go index 1206871..4f9a095 100644 --- a/pkg/security/oauth_server_db.go +++ b/pkg/security/oauth_server_db.go @@ -25,19 +25,21 @@ type OAuthCode struct { CodeChallenge string `json:"code_challenge"` CodeChallengeMethod string `json:"code_challenge_method"` SessionToken string `json:"session_token"` + RefreshToken string `json:"refresh_token,omitempty"` Scopes []string `json:"scopes,omitempty"` ExpiresAt time.Time `json:"expires_at"` } // OAuthTokenInfo is the RFC 7662 token introspection response. type OAuthTokenInfo struct { - Active bool `json:"active"` - Sub string `json:"sub,omitempty"` - Username string `json:"username,omitempty"` - Email string `json:"email,omitempty"` - Roles []string `json:"roles,omitempty"` - Exp int64 `json:"exp,omitempty"` - Iat int64 `json:"iat,omitempty"` + Active bool `json:"active"` + Sub string `json:"sub,omitempty"` + Username string `json:"username,omitempty"` + Email string `json:"email,omitempty"` + UserLevel int `json:"user_level,omitempty"` + Roles []string `json:"roles,omitempty"` + Exp int64 `json:"exp,omitempty"` + Iat int64 `json:"iat,omitempty"` } // OAuthRegisterClient persists an OAuth2 client registration.