diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 3c546fda2..8fac12a28 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -12,7 +12,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field from typing import Any, Protocol -from urllib.parse import quote, urlencode, urljoin, urlparse +from urllib.parse import quote, urljoin, urlparse import anyio import httpx @@ -353,7 +353,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if "offline_access" in self.context.client_metadata.scope.split(): auth_params["prompt"] = "consent" - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + authorization_url = str(httpx.URL(auth_endpoint).copy_merge_params(auth_params)) await self.context.redirect_handler(authorization_url) # Wait for callback diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c9..fc9b3b2e1 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1167,6 +1167,52 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide assert oauth_provider.context.current_tokens.access_token == "new_access_token" assert oauth_provider.context.token_expiry_time is not None + @pytest.mark.anyio + async def test_authorization_endpoint_preserves_existing_query_params( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + captured_auth_url: str | None = None + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_auth_url, captured_state + captured_auth_url = url + captured_state = parse_qs(urlparse(url).query)["state"][0] + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize?prompt=select_account"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + ) + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + auth_code, code_verifier = await provider._perform_authorization_code_grant() + + assert auth_code == "test_auth_code" + assert code_verifier + assert captured_auth_url is not None + parsed = urlparse(captured_auth_url) + params = parse_qs(parsed.query) + assert f"{parsed.scheme}://{parsed.netloc}{parsed.path}" == "https://auth.example.com/authorize" + assert params["prompt"] == ["select_account"] + assert params["response_type"] == ["code"] + assert params["client_id"] == ["test_client"] + assert params["redirect_uri"] == ["http://localhost:3030/callback"] + @pytest.mark.anyio async def test_auth_flow_no_unnecessary_retry_after_oauth( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken @@ -1350,13 +1396,9 @@ async def test_403_insufficient_scope_updates_scope_from_header( async def capture_redirect(url: str) -> None: nonlocal redirect_captured, captured_state redirect_captured = True - # Verify the new scope is included in authorization URL - assert "scope=admin%3Awrite+admin%3Adelete" in url or "scope=admin:write+admin:delete" in url.replace( - "%3A", ":" - ).replace("+", " ") - # Extract state from redirect URL parsed = urlparse(url) params = parse_qs(parsed.query) + assert params["scope"] == ["admin:write admin:delete"] captured_state = params.get("state", [None])[0] oauth_provider.context.redirect_handler = capture_redirect