diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..57577c6621 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -147,7 +147,7 @@ def __init__( self._session_exit_stacks = {} self._component_name_hook = component_name_hook - async def __aenter__(self) -> Self: # pragma: no cover + async def __aenter__(self) -> Self: # Enter the exit stack only if we created it ourselves if self._owns_exit_stack: await self._exit_stack.__aenter__() @@ -158,7 +158,7 @@ async def __aexit__( _exc_type: type[BaseException] | None, _exc_val: BaseException | None, _exc_tb: TracebackType | None, - ) -> bool | None: # pragma: no cover + ) -> bool | None: """Closes session exit stacks and main exit stack upon completion.""" # Only close the main exit stack if we created it @@ -323,7 +323,7 @@ async def _establish_session( await self._exit_stack.enter_async_context(session_stack) return result.server_info, session - except Exception: # pragma: no cover + except Exception: # If anything during this setup fails, ensure the session-specific # stack is closed. await session_stack.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index aa3e50e07e..73d5efbcc6 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -467,17 +467,26 @@ async def _handle_message(session_message: SessionMessage) -> None: read_stream_writer=read_stream_writer, ) - async def handle_request_async(): + async def send_message() -> None: if is_resumption: await self._handle_resumption_request(ctx) else: await self._handle_post_request(ctx) + async def handle_request_async(request: JSONRPCRequest) -> None: + try: + await send_message() + except httpx.TransportError as exc: + logger.debug("Error handling request", exc_info=True) + error_data = ErrorData(code=INTERNAL_ERROR, message=f"Transport error: {exc}") + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request.id, error=error_data)) + await ctx.read_stream_writer.send(error_msg) + # If this is a request, start a new task to handle it if isinstance(message, JSONRPCRequest): - tg.start_soon(handle_request_async) + tg.start_soon(handle_request_async, message) else: - await handle_request_async() + await send_message() async for session_message in write_stream_reader: sender_ctx = write_stream_reader.last_context diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..165a9c7337 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -51,6 +51,27 @@ def test_client_session_group_component_properties(): assert mcp_session_group.tools == {"my_tool": mock_tool} +@pytest.mark.anyio +async def test_client_session_group_context_manager_closes_session_stacks_with_external_stack(): + class SessionStack(contextlib.AsyncExitStack): + def __init__(self) -> None: + super().__init__() + self.closed = False + + async def aclose(self) -> None: + self.closed = True + await super().aclose() + + session_stack = SessionStack() + group = ClientSessionGroup(exit_stack=contextlib.AsyncExitStack()) + group._session_exit_stacks[mock.Mock(spec=mcp.ClientSession)] = session_stack + + async with group as entered: + assert entered is group + + assert session_stack.closed + + @pytest.mark.anyio async def test_client_session_group_call_tool(): # --- Mock Dependencies --- @@ -278,6 +299,25 @@ async def test_client_session_group_disconnect_non_existent_server(): await group.disconnect_from_server(session) +@pytest.mark.anyio +async def test_client_session_group_streamable_http_connection_error_surfaces() -> None: + async def fail_request(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("offline", request=request) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(fail_request)) + + with mock.patch("mcp.client.session_group.create_mcp_http_client", return_value=http_client): + async with ClientSessionGroup() as group: + with pytest.raises(MCPError) as excinfo: # pragma: no branch + await group.connect_to_server( + StreamableHttpParameters(url="http://example.test/mcp"), + ClientSessionParameters(read_timeout_seconds=2), + ) + + assert excinfo.value.error.code == types.INTERNAL_ERROR + assert excinfo.value.error.message == "Transport error: offline" + + # TODO(Marcelo): This is horrible. We should drop this test. @pytest.mark.anyio @pytest.mark.parametrize(