diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 3c546fda2..187251f54 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -9,10 +9,10 @@ import secrets import string import time -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from dataclasses import dataclass, field from typing import Any, Protocol -from urllib.parse import quote, urlencode, urljoin, urlparse +from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse import anyio import httpx @@ -53,6 +53,22 @@ logger = logging.getLogger(__name__) +def _build_authorization_url(auth_endpoint: str, auth_params: Mapping[str, str | None]) -> str: + """Build an authorization URL, preserving any query params already on the endpoint. + + Servers may advertise an ``authorization_endpoint`` that already carries query + parameters (e.g. ``https://example.com/authorize?prompt=select_account``). + Naively appending ``?`` would produce an invalid URL with two ``?`` + separators, so the existing query is parsed and merged with ``auth_params``. + Flow-generated params take precedence on key conflicts; ``None`` values are + dropped rather than serialized as the literal string ``"None"``. 
+ """ + parsed = urlparse(auth_endpoint) + merged_params = dict(parse_qsl(parsed.query, keep_blank_values=True)) + merged_params.update({key: value for key, value in auth_params.items() if value is not None}) + return urlunparse(parsed._replace(query=urlencode(merged_params))) + + class PKCEParameters(BaseModel): """PKCE (Proof Key for Code Exchange) parameters.""" @@ -353,7 +369,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if "offline_access" in self.context.client_metadata.scope.split(): auth_params["prompt"] = "consent" - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + authorization_url = _build_authorization_url(auth_endpoint, auth_params) await self.context.redirect_handler(authorization_url) # Wait for callback diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c9..8025d07cf 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -12,6 +12,7 @@ from mcp.client.auth import OAuthClientProvider, PKCEParameters from mcp.client.auth.exceptions import OAuthFlowError +from mcp.client.auth.oauth2 import _build_authorization_url from mcp.client.auth.utils import ( build_oauth_authorization_server_metadata_discovery_urls, build_protected_resource_metadata_discovery_urls, @@ -2618,3 +2619,86 @@ async def callback_handler() -> tuple[str, str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +class TestAuthorizationEndpointWithQuery: + """Regression tests for #2776 - authorization_endpoint carrying query params.""" + + def test_build_authorization_url_no_existing_query(self): + url = _build_authorization_url( + "https://auth.example.com/authorize", + {"response_type": "code", "client_id": "abc"}, + ) + parsed = urlparse(url) + params = parse_qs(parsed.query) + assert parsed.path == "/authorize" + assert params["response_type"] == ["code"] + assert params["client_id"] == ["abc"] + # No malformed double "?" separator. 
class TestAuthorizationEndpointWithQuery:
    """Regression tests for #2776: ``authorization_endpoint`` values that already carry a query."""

    def test_build_authorization_url_no_existing_query(self):
        """A bare endpoint gets the flow params after a single ``?``."""
        flow_params = {"response_type": "code", "client_id": "abc"}
        result = _build_authorization_url("https://auth.example.com/authorize", flow_params)

        parts = urlparse(result)
        query = parse_qs(parts.query)
        assert parts.path == "/authorize"
        assert query["response_type"] == ["code"]
        assert query["client_id"] == ["abc"]
        # Guard against a malformed URL with a doubled "?" separator.
        assert result.count("?") == 1

    def test_build_authorization_url_preserves_existing_query(self):
        """Params already on the endpoint are kept next to the flow params."""
        # Endpoint shaped like the Salesforce one: .../authorize?prompt=select_account
        endpoint = "https://test.salesforce.com/services/oauth2/authorize?prompt=select_account"
        result = _build_authorization_url(endpoint, {"response_type": "code", "client_id": "abc"})

        parts = urlparse(result)
        query = parse_qs(parts.query)
        assert parts.path == "/services/oauth2/authorize"
        # Server-provided parameter is still present.
        assert query["prompt"] == ["select_account"]
        # Flow-generated parameters are present too.
        assert query["response_type"] == ["code"]
        assert query["client_id"] == ["abc"]
        # Plain string concatenation would have produced a second "?" here.
        assert result.count("?") == 1

    def test_build_authorization_url_flow_params_win_on_conflict(self):
        """A flow param overrides an endpoint param that uses the same key."""
        result = _build_authorization_url(
            "https://auth.example.com/authorize?response_type=token",
            {"response_type": "code"},
        )
        query = parse_qs(urlparse(result).query)
        assert query["response_type"] == ["code"]

    @pytest.mark.anyio
    async def test_perform_authorization_preserves_endpoint_query(self, oauth_provider: OAuthClientProvider):
        """End-to-end: the redirect URL stays well-formed when the endpoint has a query string."""
        oauth_provider.context.oauth_metadata = OAuthMetadata(
            issuer=AnyHttpUrl("https://test.salesforce.com"),
            authorization_endpoint=AnyHttpUrl(
                "https://test.salesforce.com/services/oauth2/authorize?prompt=select_account"
            ),
            token_endpoint=AnyHttpUrl("https://test.salesforce.com/services/oauth2/token"),
        )
        oauth_provider.context.client_info = OAuthClientInformationFull(
            client_id="test_client_id",
            client_secret="test_client_secret",
            redirect_uris=[AnyUrl("http://localhost:3030/callback")],
        )

        # Filled in by the redirect handler; the callback echoes the state back.
        seen: dict[str, str | None] = {"url": None, "state": None}

        async def capture_redirect(url: str) -> None:
            seen["url"] = url
            seen["state"] = parse_qs(urlparse(url).query).get("state", [None])[0]

        async def mock_callback() -> tuple[str, str | None]:
            return "test_auth_code", seen["state"]

        oauth_provider.context.redirect_handler = capture_redirect
        oauth_provider.context.callback_handler = mock_callback

        await oauth_provider._perform_authorization_code_grant()

        redirect_url = seen["url"]
        assert redirect_url is not None
        parts = urlparse(redirect_url)
        query = parse_qs(parts.query)
        assert parts.path == "/services/oauth2/authorize"
        assert query["prompt"] == ["select_account"]
        assert query["response_type"] == ["code"]
        assert query["client_id"] == ["test_client_id"]
        assert redirect_url.count("?") == 1