Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import secrets
import string
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse

import anyio
import httpx
Expand Down Expand Up @@ -53,6 +53,22 @@
logger = logging.getLogger(__name__)


def _build_authorization_url(auth_endpoint: str, auth_params: Mapping[str, str | None]) -> str:
"""Build an authorization URL, preserving any query params already on the endpoint.

Servers may advertise an ``authorization_endpoint`` that already carries query
parameters (e.g. ``https://example.com/authorize?prompt=select_account``).
Naively appending ``?<params>`` would produce an invalid URL with two ``?``
separators, so the existing query is parsed and merged with ``auth_params``.
Flow-generated params take precedence on key conflicts; ``None`` values are
dropped rather than serialized as the literal string ``"None"``.
"""
parsed = urlparse(auth_endpoint)
merged_params = dict(parse_qsl(parsed.query, keep_blank_values=True))
merged_params.update({key: value for key, value in auth_params.items() if value is not None})
return urlunparse(parsed._replace(query=urlencode(merged_params)))


class PKCEParameters(BaseModel):
"""PKCE (Proof Key for Code Exchange) parameters."""

Expand Down Expand Up @@ -353,7 +369,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
if "offline_access" in self.context.client_metadata.scope.split():
auth_params["prompt"] = "consent"

authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
authorization_url = _build_authorization_url(auth_endpoint, auth_params)
await self.context.redirect_handler(authorization_url)

# Wait for callback
Expand Down
84 changes: 84 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from mcp.client.auth import OAuthClientProvider, PKCEParameters
from mcp.client.auth.exceptions import OAuthFlowError
from mcp.client.auth.oauth2 import _build_authorization_url
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
Expand Down Expand Up @@ -2618,3 +2619,86 @@ async def callback_handler() -> tuple[str, str | None]:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass


class TestAuthorizationEndpointWithQuery:
"""Regression tests for #2776 - authorization_endpoint carrying query params."""

def test_build_authorization_url_no_existing_query(self):
url = _build_authorization_url(
"https://auth.example.com/authorize",
{"response_type": "code", "client_id": "abc"},
)
parsed = urlparse(url)
params = parse_qs(parsed.query)
assert parsed.path == "/authorize"
assert params["response_type"] == ["code"]
assert params["client_id"] == ["abc"]
# No malformed double "?" separator.
assert url.count("?") == 1

def test_build_authorization_url_preserves_existing_query(self):
# e.g. Salesforce advertises .../authorize?prompt=select_account
url = _build_authorization_url(
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account",
{"response_type": "code", "client_id": "abc"},
)
parsed = urlparse(url)
params = parse_qs(parsed.query)
assert parsed.path == "/services/oauth2/authorize"
# The server-provided param survives...
assert params["prompt"] == ["select_account"]
# ...alongside the flow-generated params.
assert params["response_type"] == ["code"]
assert params["client_id"] == ["abc"]
# Exactly one "?" - the old f-string produced "...?prompt=...?response_type=...".
assert url.count("?") == 1

def test_build_authorization_url_flow_params_win_on_conflict(self):
url = _build_authorization_url(
"https://auth.example.com/authorize?response_type=token",
{"response_type": "code"},
)
params = parse_qs(urlparse(url).query)
assert params["response_type"] == ["code"]

@pytest.mark.anyio
async def test_perform_authorization_preserves_endpoint_query(self, oauth_provider: OAuthClientProvider):
"""End-to-end: redirect URL stays valid when the endpoint has a query string."""
oauth_provider.context.oauth_metadata = OAuthMetadata(
issuer=AnyHttpUrl("https://test.salesforce.com"),
authorization_endpoint=AnyHttpUrl(
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account"
),
token_endpoint=AnyHttpUrl("https://test.salesforce.com/services/oauth2/token"),
)
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)

captured_url: str | None = None
captured_state: str | None = None

async def capture_redirect(url: str) -> None:
nonlocal captured_url, captured_state
captured_url = url
captured_state = parse_qs(urlparse(url).query).get("state", [None])[0]

async def mock_callback() -> tuple[str, str | None]:
return "test_auth_code", captured_state

oauth_provider.context.redirect_handler = capture_redirect
oauth_provider.context.callback_handler = mock_callback

await oauth_provider._perform_authorization_code_grant()

assert captured_url is not None
parsed = urlparse(captured_url)
params = parse_qs(parsed.query)
assert parsed.path == "/services/oauth2/authorize"
assert params["prompt"] == ["select_account"]
assert params["response_type"] == ["code"]
assert params["client_id"] == ["test_client_id"]
assert captured_url.count("?") == 1
Loading