Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,39 +344,41 @@ async def _aggregate_components(self, server_info: types.Implementation, session
tools_temp: dict[str, types.Tool] = {}
tool_to_session_temp: dict[str, mcp.ClientSession] = {}

# Query capabilities negotiated during initialize().
capabilities = session.initialize_result.capabilities if session.initialize_result is not None else None
# Query the server for its prompts and aggregate to list.
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch prompts: {err}")
if capabilities is None or capabilities.prompts is not None:
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch prompts: {err}")

# Query the server for its resources and aggregate to list.
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch resources: {err}")
if capabilities is None or capabilities.resources is not None:
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch resources: {err}")

# Query the server for its tools and aggregate to list.
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch tools: {err}")

# Clean up exit stack for session if we couldn't retrieve anything
# from the server.
if capabilities is None or capabilities.tools is not None:
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch tools: {err}")
if not any((prompts_temp, resources_temp, tools_temp)):
del self._session_exit_stacks[session] # pragma: no cover

Expand Down
87 changes: 87 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,93 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
mock_session.list_prompts.assert_awaited_once()


@pytest.mark.anyio
@pytest.mark.anyio
async def test_client_session_group_skips_unsupported_capabilities(
mock_exit_stack: contextlib.AsyncExitStack,
):
"""Only query capabilities advertised by the server."""

mock_server_info = mock.Mock(spec=types.Implementation)
mock_server_info.name = "ToolsOnlyServer"

mock_session = mock.AsyncMock(spec=mcp.ClientSession)

mock_tool = mock.Mock(spec=types.Tool)
mock_tool.name = "ping"

mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])
mock_session.list_resources.return_value = mock.AsyncMock(resources=[])
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[])

capabilities = mock.Mock()
capabilities.tools = object()
capabilities.prompts = None
capabilities.resources = None

initialize_result = mock.Mock()
initialize_result.capabilities = capabilities

mock_session.initialize_result = initialize_result

group = ClientSessionGroup(exit_stack=mock_exit_stack)

await group._aggregate_components(
mock_server_info,
mock_session,
)

mock_session.list_tools.assert_awaited_once()
mock_session.list_prompts.assert_not_awaited()
mock_session.list_resources.assert_not_awaited()

assert "ping" in group.tools


@pytest.mark.anyio
@pytest.mark.anyio
async def test_client_session_group_skips_unsupported_tools(
mock_exit_stack: contextlib.AsyncExitStack,
):
mock_server_info = mock.Mock(spec=types.Implementation)
mock_server_info.name = "TestServer"

mock_session = mock.AsyncMock(spec=mcp.ClientSession)

mock_prompt = mock.Mock(spec=types.Prompt)
mock_prompt.name = "prompt"

mock_resource = mock.Mock(spec=types.Resource)
mock_resource.name = "resource"

mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt])
mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource])
mock_session.list_tools.return_value = mock.AsyncMock(tools=[])

capabilities = mock.Mock()
capabilities.tools = None
capabilities.prompts = object()
capabilities.resources = object()

initialize_result = mock.Mock()
initialize_result.capabilities = capabilities
mock_session.initialize_result = initialize_result

group = ClientSessionGroup(exit_stack=mock_exit_stack)

await group._aggregate_components(
mock_server_info,
mock_session,
)

mock_session.list_tools.assert_not_awaited()
mock_session.list_prompts.assert_awaited_once()
mock_session.list_resources.assert_awaited_once()

assert "prompt" in group.prompts
assert "resource" in group.resources


@pytest.mark.anyio
async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack):
"""Test connecting with a component name hook."""
Expand Down
Loading