Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
from urllib.parse import quote, urljoin, urlparse

import anyio
import httpx
Expand Down Expand Up @@ -353,7 +353,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
if "offline_access" in self.context.client_metadata.scope.split():
auth_params["prompt"] = "consent"

authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
authorization_url = str(httpx.URL(auth_endpoint).copy_merge_params(auth_params))
await self.context.redirect_handler(authorization_url)

# Wait for callback
Expand Down
52 changes: 47 additions & 5 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,52 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
assert oauth_provider.context.token_expiry_time is not None

@pytest.mark.anyio
async def test_authorization_endpoint_preserves_existing_query_params(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
):
captured_auth_url: str | None = None
captured_state: str | None = None

async def redirect_handler(url: str) -> None:
nonlocal captured_auth_url, captured_state
captured_auth_url = url
captured_state = parse_qs(urlparse(url).query)["state"][0]

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", captured_state

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)
provider.context.oauth_metadata = OAuthMetadata(
issuer=AnyHttpUrl("https://auth.example.com"),
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize?prompt=select_account"),
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
)
provider.context.client_info = OAuthClientInformationFull(
client_id="test_client",
client_secret="test_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)

auth_code, code_verifier = await provider._perform_authorization_code_grant()

assert auth_code == "test_auth_code"
assert code_verifier
assert captured_auth_url is not None
parsed = urlparse(captured_auth_url)
params = parse_qs(parsed.query)
assert f"{parsed.scheme}://{parsed.netloc}{parsed.path}" == "https://auth.example.com/authorize"
assert params["prompt"] == ["select_account"]
assert params["response_type"] == ["code"]
assert params["client_id"] == ["test_client"]
assert params["redirect_uri"] == ["http://localhost:3030/callback"]

@pytest.mark.anyio
async def test_auth_flow_no_unnecessary_retry_after_oauth(
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
Expand Down Expand Up @@ -1350,13 +1396,9 @@ async def test_403_insufficient_scope_updates_scope_from_header(
async def capture_redirect(url: str) -> None:
nonlocal redirect_captured, captured_state
redirect_captured = True
# Verify the new scope is included in authorization URL
assert "scope=admin%3Awrite+admin%3Adelete" in url or "scope=admin:write+admin:delete" in url.replace(
"%3A", ":"
).replace("+", " ")
# Extract state from redirect URL
parsed = urlparse(url)
params = parse_qs(parsed.query)
assert params["scope"] == ["admin:write admin:delete"]
captured_state = params.get("state", [None])[0]

oauth_provider.context.redirect_handler = capture_redirect
Expand Down
Loading