diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..8f8907f47 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -344,39 +344,41 @@ async def _aggregate_components(self, server_info: types.Implementation, session tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} + # Query capabilities negotiated during initialize(). + capabilities = session.initialize_result.capabilities if session.initialize_result is not None else None # Query the server for its prompts and aggregate to list. - try: - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - prompts_temp[name] = prompt - component_names.prompts.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch prompts: {err}") + if capabilities is None or capabilities.prompts is not None: + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - try: - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - resources_temp[name] = resource - component_names.resources.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch resources: {err}") + if capabilities is None or capabilities.resources is not None: + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - try: - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch tools: {err}") - - # Clean up exit stack for session if we couldn't retrieve anything - # from the server. + if capabilities is None or capabilities.tools is not None: + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch tools: {err}") if not any((prompts_temp, resources_temp, tools_temp)): del self._session_exit_stacks[session] # pragma: no cover diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..e2f54d2d8 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -125,6 +125,93 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli mock_session.list_prompts.assert_awaited_once() +@pytest.mark.anyio +@pytest.mark.anyio +async def test_client_session_group_skips_unsupported_capabilities( + mock_exit_stack: contextlib.AsyncExitStack, +): + """Only query capabilities advertised by the server.""" + + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "ToolsOnlyServer" + + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + + mock_tool = mock.Mock(spec=types.Tool) + mock_tool.name = "ping" + + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + capabilities = mock.Mock() + capabilities.tools = object() + capabilities.prompts = None + capabilities.resources = None + + initialize_result = mock.Mock() + initialize_result.capabilities = capabilities + + mock_session.initialize_result = initialize_result + + group = ClientSessionGroup(exit_stack=mock_exit_stack) + + await group._aggregate_components( + mock_server_info, + mock_session, + ) + + mock_session.list_tools.assert_awaited_once() + mock_session.list_prompts.assert_not_awaited() + mock_session.list_resources.assert_not_awaited() + + assert "ping" in group.tools + + +@pytest.mark.anyio +@pytest.mark.anyio +async def test_client_session_group_skips_unsupported_tools( + mock_exit_stack: contextlib.AsyncExitStack, +): + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "TestServer" + + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + + mock_prompt = mock.Mock(spec=types.Prompt) + mock_prompt.name = "prompt" + + mock_resource = mock.Mock(spec=types.Resource) + mock_resource.name = "resource" + + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource]) + mock_session.list_tools.return_value = mock.AsyncMock(tools=[]) + + capabilities = mock.Mock() + capabilities.tools = None + capabilities.prompts = object() + capabilities.resources = object() + + initialize_result = mock.Mock() + initialize_result.capabilities = capabilities + mock_session.initialize_result = initialize_result + + group = ClientSessionGroup(exit_stack=mock_exit_stack) + + await group._aggregate_components( + mock_server_info, + mock_session, + ) + + mock_session.list_tools.assert_not_awaited() + mock_session.list_prompts.assert_awaited_once() + mock_session.list_resources.assert_awaited_once() + + assert "prompt" in group.prompts + assert "resource" in group.resources + + @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook."""