mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-11 02:13:53 +00:00
fix(security): address all OAuth2 PR review issues
Agent-Logs-Url: https://github.com/bitechdev/ResolveSpec/sessions/e886b781-c910-425f-aa6f-06d13c46dcc7 Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
2a2e33da0c
commit
850ad2b2ab
@@ -25,7 +25,7 @@ type OAuthServerConfig struct {
|
||||
ProviderCallbackPath string
|
||||
|
||||
// LoginTitle is shown on the built-in login form when the server acts as its own
|
||||
// identity provider. Defaults to "MCP Login".
|
||||
// identity provider. Defaults to "Sign in".
|
||||
LoginTitle string
|
||||
|
||||
// PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided.
|
||||
@@ -65,6 +65,7 @@ type pendingAuth struct {
|
||||
ProviderName string // empty = password login
|
||||
ExpiresAt time.Time
|
||||
SessionToken string // set after authentication completes
|
||||
RefreshToken string // set after authentication completes when refresh tokens are issued
|
||||
Scopes []string // requested scopes
|
||||
}
|
||||
|
||||
@@ -100,6 +101,8 @@ type OAuthServer struct {
|
||||
clients map[string]*oauthClient
|
||||
pending map[string]*pendingAuth // provider_state → pending (external flow)
|
||||
codes map[string]*pendingAuth // auth_code → pending (post-auth)
|
||||
|
||||
done chan struct{} // closed by Close() to stop background goroutines
|
||||
}
|
||||
|
||||
// NewOAuthServer creates a new MCP OAuth2 authorization server.
|
||||
@@ -107,6 +110,8 @@ type OAuthServer struct {
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login (the server
|
||||
// acts as its own identity provider). Pass nil to use only external providers.
|
||||
// External providers are added separately via RegisterExternalProvider.
|
||||
//
|
||||
// Call Close() to stop background goroutines when the server is no longer needed.
|
||||
func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer {
|
||||
if cfg.ProviderCallbackPath == "" {
|
||||
cfg.ProviderCallbackPath = "/oauth/provider/callback"
|
||||
@@ -123,23 +128,40 @@ func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthSe
|
||||
if cfg.AuthCodeTTL == 0 {
|
||||
cfg.AuthCodeTTL = 2 * time.Minute
|
||||
}
|
||||
// Normalize issuer: trim trailing slash to ensure consistent endpoint URL construction.
|
||||
cfg.Issuer = strings.TrimRight(cfg.Issuer, "/")
|
||||
s := &OAuthServer{
|
||||
cfg: cfg,
|
||||
auth: auth,
|
||||
clients: make(map[string]*oauthClient),
|
||||
pending: make(map[string]*pendingAuth),
|
||||
codes: make(map[string]*pendingAuth),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go s.cleanupExpired()
|
||||
return s
|
||||
}
|
||||
|
||||
// Close stops the background goroutines started by NewOAuthServer.
|
||||
// It is safe to call Close multiple times.
|
||||
func (s *OAuthServer) Close() {
|
||||
select {
|
||||
case <-s.done:
|
||||
// already closed
|
||||
default:
|
||||
close(s.done)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.)
|
||||
// that handles user authentication via redirect. The DatabaseAuthenticator must have been
|
||||
// configured with WithOAuth2(providerName, ...) before calling this.
|
||||
// Multiple providers can be registered; the first is used as the default.
|
||||
// All providers must be registered before the server starts serving requests.
|
||||
func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) {
|
||||
s.mu.Lock()
|
||||
s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName})
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// ProviderCallbackPath returns the configured path for external provider callbacks.
|
||||
@@ -169,20 +191,25 @@ func (s *OAuthServer) HTTPHandler() http.Handler {
|
||||
func (s *OAuthServer) cleanupExpired() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
for k, p := range s.pending {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.pending, k)
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
for k, p := range s.pending {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.pending, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
for k, p := range s.codes {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.codes, k)
|
||||
for k, p := range s.codes {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.codes, k)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -383,7 +410,7 @@ func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
|
||||
}
|
||||
|
||||
// redirectToExternalProvider stores the pending auth and redirects to the configured provider.
|
||||
@@ -469,13 +496,13 @@ func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Req
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token,
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken,
|
||||
pending.ClientID, pending.RedirectURI, pending.ClientState,
|
||||
pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes)
|
||||
}
|
||||
|
||||
// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client.
|
||||
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, refreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
authCode, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
@@ -490,6 +517,7 @@ func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Reques
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: providerName,
|
||||
SessionToken: sessionToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL),
|
||||
Scopes: scopes,
|
||||
}
|
||||
@@ -503,6 +531,7 @@ func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Reques
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
SessionToken: sessionToken,
|
||||
RefreshToken: refreshToken,
|
||||
Scopes: scopes,
|
||||
ExpiresAt: pending.ExpiresAt,
|
||||
}
|
||||
@@ -565,6 +594,7 @@ func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
var sessionToken string
|
||||
var refreshToken string
|
||||
var scopes []string
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
@@ -586,6 +616,7 @@ func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
sessionToken = oauthCode.SessionToken
|
||||
refreshToken = oauthCode.RefreshToken
|
||||
scopes = oauthCode.Scopes
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
@@ -612,10 +643,11 @@ func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
sessionToken = pending.SessionToken
|
||||
refreshToken = pending.RefreshToken
|
||||
scopes = pending.Scopes
|
||||
}
|
||||
|
||||
writeOAuthToken(w, sessionToken, "", scopes)
|
||||
s.writeOAuthToken(w, sessionToken, refreshToken, scopes)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -634,7 +666,7 @@ func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request)
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -644,7 +676,7 @@ func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request)
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -672,6 +704,9 @@ func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if s.auth != nil {
|
||||
s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
|
||||
} else if len(s.providers) > 0 {
|
||||
// In external-provider-only mode, attempt revocation via the first provider's auth.
|
||||
s.providers[0].auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
@@ -693,12 +728,22 @@ func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request)
|
||||
token := r.FormValue("token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if token == "" || s.auth == nil {
|
||||
if token == "" {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.auth.OAuthIntrospectToken(r.Context(), token)
|
||||
// Resolve the authenticator to use: prefer the primary auth, then the first provider's auth.
|
||||
authToUse := s.auth
|
||||
if authToUse == nil && len(s.providers) > 0 {
|
||||
authToUse = s.providers[0].auth
|
||||
}
|
||||
if authToUse == nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
|
||||
info, err := authToUse.OAuthIntrospectToken(r.Context(), token)
|
||||
if err != nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
@@ -740,7 +785,7 @@ button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;borde
|
||||
button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}</style>
|
||||
</head><body><div class="card">
|
||||
<h2>%s</h2>%s
|
||||
<form method="POST" action="/oauth/authorize">
|
||||
<form method="POST" action="authorize">
|
||||
<input type="hidden" name="client_id" value="%s">
|
||||
<input type="hidden" name="redirect_uri" value="%s">
|
||||
<input type="hidden" name="client_state" value="%s">
|
||||
@@ -815,18 +860,19 @@ func randomOAuthToken() (string, error) {
|
||||
|
||||
func oauthSliceContains(slice []string, s string) bool {
|
||||
for _, v := range slice {
|
||||
if strings.EqualFold(v, s) {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
|
||||
func (s *OAuthServer) writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
|
||||
expiresIn := int64(s.cfg.AccessTokenTTL.Seconds())
|
||||
resp := map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
"expires_in": expiresIn,
|
||||
}
|
||||
if refreshToken != "" {
|
||||
resp["refresh_token"] = refreshToken
|
||||
|
||||
Reference in New Issue
Block a user