Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,19 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")

response_complete = False
async for sse in event_source.aiter_sse(): # pragma: no branch
if response_complete:
continue

is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
await event_source.response.aclose()
break
response_complete = True

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
Expand Down Expand Up @@ -340,9 +343,13 @@ async def _handle_sse_response(
assert isinstance(ctx.session_message.message, JSONRPCRequest)
original_request_id = ctx.session_message.message.id

response_complete = False
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse(): # pragma: no branch
if response_complete:
continue

# Track last event ID for potential reconnection
if sse.id:
last_event_id = sse.id
Expand All @@ -359,13 +366,15 @@ async def _handle_sse_response(
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning response/error
# break the loop
# keep draining the response to EOF so the HTTP connection can be reused.
if is_complete:
await response.aclose()
return # Normal completion, no reconnect needed
response_complete = True
except Exception:
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover

if response_complete:
return # Normal completion, no reconnect needed

# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
logger.info("SSE stream disconnected, reconnecting...")
Expand Down Expand Up @@ -405,7 +414,11 @@ async def _handle_reconnection(
reconnect_last_event_id: str = last_event_id
reconnect_retry_ms = retry_interval_ms

response_complete = False
async for sse in event_source.aiter_sse():
if response_complete:
continue

if sse.id: # pragma: no branch
reconnect_last_event_id = sse.id
if sse.retry is not None:
Expand All @@ -418,13 +431,15 @@ async def _handle_reconnection(
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
await event_source.response.aclose()
return
response_complete = True

if response_complete:
return

# Stream ended again without response - reconnect again (reset attempt counter)
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
except Exception as e: # pragma: no cover
except Exception as e: # pragma: lax no cover
logger.debug(f"Reconnection failed: {e}")
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
Expand Down
4 changes: 2 additions & 2 deletions tests/interaction/transports/test_hosting_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
capture = ClientMessageMetadata(on_resumption_token_update=on_token)

async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager):
with anyio.fail_after(5): # pragma: no branch
with anyio.fail_after(5): # pragma: lax no cover
async with ( # pragma: no branch
streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1),
ClientSession(r1, w1, logging_callback=collect) as first,
Expand All @@ -357,7 +357,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION
tg.cancel_scope.cancel()

with anyio.fail_after(5): # pragma: no branch
with anyio.fail_after(5): # pragma: lax no cover
release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event
# init priming + init response + call priming + "first" + "second" + result = 6 stored events.
await store.wait_until_stored(6)
Expand Down
122 changes: 122 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from starlette.requests import Request
from starlette.routing import Mount

import mcp.client.streamable_http as streamable_http
from mcp import MCPError, types
from mcp.client.session import ClientSession
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
Expand Down Expand Up @@ -139,6 +140,39 @@ async def replay_events_after( # pragma: no cover
return target_stream_id


class FakeStreamResponse(httpx.Response):
def __init__(self) -> None:
super().__init__(
200,
request=httpx.Request("POST", "http://localhost:8000/mcp"),
)
self.close_count = 0

async def aclose(self) -> None: # pragma: no cover
self.close_count += 1


class FakeEventSource:
def __init__(self, events: list[ServerSentEvent]) -> None:
self.response = FakeStreamResponse()
self.events = events
self.seen = 0

async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]:
for event in self.events:
self.seen += 1
yield event


def jsonrpc_response_event(request_id: str, event_id: str) -> ServerSentEvent:
return ServerSentEvent(
event="message",
data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}),
id=event_id,
retry=None,
)


@dataclass
class ServerState:
lock: anyio.Event = field(default_factory=anyio.Event)
Expand Down Expand Up @@ -1803,6 +1837,94 @@ async def test_handle_sse_event_skips_empty_data():
await read_stream.aclose()


@pytest.mark.anyio
async def test_handle_sse_response_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch):
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
response = FakeStreamResponse()
event_source = FakeEventSource(
[
jsonrpc_response_event("request-1", "event-1"),
ServerSentEvent(event="message", data="", id="event-2", retry=None),
]
)

def event_source_factory(_response: httpx.Response) -> FakeEventSource:
return event_source

monkeypatch.setattr(streamable_http, "EventSource", event_source_factory)

write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
try:
async with httpx.AsyncClient() as client:
ctx = streamable_http.RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(
JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={})
),
metadata=None,
read_stream_writer=write_stream,
)

await transport._handle_sse_response(response, ctx)

received = await read_stream.receive()
assert isinstance(received, SessionMessage)
assert isinstance(received.message, types.JSONRPCResponse)
assert received.message.id == "request-1"
assert event_source.seen == 2
assert response.close_count == 0
finally:
await write_stream.aclose()
await read_stream.aclose()


@pytest.mark.anyio
async def test_reconnection_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch):
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
event_source = FakeEventSource(
[
jsonrpc_response_event("request-1", "event-2"),
ServerSentEvent(event="message", data="", id="event-3", retry=None),
]
)

async def sleep_noop(_delay: float) -> None:
pass

@asynccontextmanager
async def connect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[FakeEventSource]:
yield event_source

monkeypatch.setattr(streamable_http.anyio, "sleep", sleep_noop)
monkeypatch.setattr(streamable_http, "aconnect_sse", connect_sse)

write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
try:
async with httpx.AsyncClient() as client:
ctx = streamable_http.RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(
JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={})
),
metadata=None,
read_stream_writer=write_stream,
)

await transport._handle_reconnection(ctx, last_event_id="event-1")

received = await read_stream.receive()
assert isinstance(received, SessionMessage)
assert isinstance(received.message, types.JSONRPCResponse)
assert received.message.id == "request-1"
assert event_source.seen == 2
assert event_source.response.close_count == 0
finally:
await write_stream.aclose()
await read_stream.aclose()


@pytest.mark.anyio
async def test_priming_event_not_sent_for_old_protocol_version():
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""
Expand Down
Loading