From 5f14f6d8d3d2c0cbe43e0db4d5e9874f733d02a0 Mon Sep 17 00:00:00 2001 From: bobby-nandigam Date: Sun, 7 Jun 2026 12:20:00 +0530 Subject: [PATCH 1/3] Respect negotiated capabilities in ClientSessionGroup --- src/mcp/client/session_group.py | 62 ++++++++++++++++-------------- tests/client/test_session_group.py | 43 +++++++++++++++++++++ 2 files changed, 77 insertions(+), 28 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..04aa5ed895 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -344,39 +344,45 @@ async def _aggregate_components(self, server_info: types.Implementation, session tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} + # Query capabilities negotiated during initialize(). + capabilities = ( + session.initialize_result.capabilities + if session.initialize_result is not None + else None + ) # Query the server for its prompts and aggregate to list. - try: - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - prompts_temp[name] = prompt - component_names.prompts.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch prompts: {err}") + if capabilities is None or capabilities.prompts is not None: + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - try: - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - resources_temp[name] = resource - component_names.resources.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch resources: {err}") + if capabilities is None or capabilities.resources is not None: + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - try: - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch tools: {err}") - - # Clean up exit stack for session if we couldn't retrieve anything - # from the server. + if capabilities is None or capabilities.tools is not None: + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch tools: {err}") if not any((prompts_temp, resources_temp, tools_temp)): del self._session_exit_stacks[session] # pragma: no cover diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..ec7b6002b0 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -125,6 +125,49 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli mock_session.list_prompts.assert_awaited_once() +@pytest.mark.anyio + +@pytest.mark.anyio +async def test_client_session_group_skips_unsupported_capabilities( + mock_exit_stack: contextlib.AsyncExitStack, +): + """Only query capabilities advertised by the server.""" + + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "ToolsOnlyServer" + + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + + mock_tool = mock.Mock(spec=types.Tool) + mock_tool.name = "ping" + + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + capabilities = mock.Mock() + capabilities.tools = object() + capabilities.prompts = None + capabilities.resources = None + + initialize_result = mock.Mock() + initialize_result.capabilities = capabilities + + mock_session.initialize_result = initialize_result + + group = ClientSessionGroup(exit_stack=mock_exit_stack) + + await group._aggregate_components( + mock_server_info, + mock_session, + ) + + mock_session.list_tools.assert_awaited_once() + mock_session.list_prompts.assert_not_awaited() + mock_session.list_resources.assert_not_awaited() + + assert "ping" in group.tools + @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook.""" From 9bfbe1cf0d425c1329d91b2f786d74cdc482e882 Mon Sep 17 00:00:00 2001 From: bobby-nandigam Date: Sun, 7 Jun 2026 12:41:56 +0530 Subject: [PATCH 2/3] style: apply ruff formatting --- src/mcp/client/session_group.py | 6 +----- tests/client/test_session_group.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 04aa5ed895..8f8907f47c 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -345,11 +345,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session tool_to_session_temp: dict[str, mcp.ClientSession] = {} # Query capabilities negotiated during initialize(). - capabilities = ( - session.initialize_result.capabilities - if session.initialize_result is not None - else None - ) + capabilities = session.initialize_result.capabilities if session.initialize_result is not None else None # Query the server for its prompts and aggregate to list. if capabilities is None or capabilities.prompts is not None: try: diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index ec7b6002b0..6b85ba5082 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -126,7 +126,6 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli @pytest.mark.anyio - @pytest.mark.anyio async def test_client_session_group_skips_unsupported_capabilities( mock_exit_stack: contextlib.AsyncExitStack, @@ -168,6 +167,7 @@ async def test_client_session_group_skips_unsupported_capabilities( assert "ping" in group.tools + @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook.""" From d2590f5c190c9b69ed309114edcbe7c6f7c17577 Mon Sep 17 00:00:00 2001 From: bobby-nandigam Date: Sun, 7 Jun 2026 13:00:15 +0530 Subject: [PATCH 3/3] test: restore anyio marker and add capability coverage tests --- tests/client/test_session_group.py | 44 ++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6b85ba5082..e2f54d2d8b 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -168,6 +168,50 @@ async def test_client_session_group_skips_unsupported_capabilities( assert "ping" in group.tools +@pytest.mark.anyio +@pytest.mark.anyio +async def test_client_session_group_skips_unsupported_tools( + mock_exit_stack: contextlib.AsyncExitStack, +): + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "TestServer" + + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + + mock_prompt = mock.Mock(spec=types.Prompt) + mock_prompt.name = "prompt" + + mock_resource = mock.Mock(spec=types.Resource) + mock_resource.name = "resource" + + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource]) + mock_session.list_tools.return_value = mock.AsyncMock(tools=[]) + + capabilities = mock.Mock() + capabilities.tools = None + capabilities.prompts = object() + capabilities.resources = object() + + initialize_result = mock.Mock() + initialize_result.capabilities = capabilities + mock_session.initialize_result = initialize_result + + group = ClientSessionGroup(exit_stack=mock_exit_stack) + + await group._aggregate_components( + mock_server_info, + mock_session, + ) + + mock_session.list_tools.assert_not_awaited() + mock_session.list_prompts.assert_awaited_once() + mock_session.list_resources.assert_awaited_once() + + assert "prompt" in group.prompts + assert "resource" in group.resources + + @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook."""