Skip to content
Open
57 changes: 44 additions & 13 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from anyio.abc import TaskGroup
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from mcp_types import (
CONNECTION_CLOSED,
INTERNAL_ERROR,
INVALID_REQUEST,
METHOD_NOT_FOUND,
Expand Down Expand Up @@ -43,6 +44,16 @@
StreamWriter = ContextSendStream[SessionMessageOrError]
StreamReader = ContextReceiveStream[SessionMessage]


async def _send_or_ignore_closed(read_stream_writer: StreamWriter, message: SessionMessageOrError) -> bool:
try:
await read_stream_writer.send(message)
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("Read stream closed before Streamable HTTP message could be delivered", exc_info=True)
return False
return True


MCP_SESSION_ID = "mcp-session-id"
LAST_EVENT_ID = "last-event-id"

Expand Down Expand Up @@ -156,17 +167,17 @@ async def _handle_sse_event(
# Otherwise, return False to continue listening
return isinstance(message, JSONRPCResponse | JSONRPCError)

# Forwarding to a closed read stream lands here when the caller cancels mid-SSE
# (BrokenResourceError, not a parse failure); coverage is timing-dependent in the
# streaming story's modern HTTP cancellation leg.
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("Read stream closed while forwarding SSE message", exc_info=True)
return True
except Exception as exc: # pragma: lax no cover
logger.exception("Error parsing SSE message")
if original_request_id is not None:
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse SSE message: {exc}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
await read_stream_writer.send(error_msg)
await _send_or_ignore_closed(read_stream_writer, error_msg)
return True
await read_stream_writer.send(exc)
await _send_or_ignore_closed(read_stream_writer, exc)
return False
else: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}")
Expand Down Expand Up @@ -377,10 +388,16 @@ async def _handle_sse_response(
except Exception:
logger.debug("SSE stream ended", exc_info=True) # pragma: lax no cover

# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
# Stream ended without a terminal response/error. If the server provided an event id,
# try resuming; otherwise fail the request instead of hanging forever.
if last_event_id is None:
error_data = ErrorData(code=CONNECTION_CLOSED, message="SSE stream disconnected before response completed")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
await _send_or_ignore_closed(ctx.read_stream_writer, error_msg)
return

logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)

async def _handle_reconnection(
self,
Expand All @@ -391,7 +408,16 @@ async def _handle_reconnection(
) -> None:
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
# Bail if max retries exceeded
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
if attempt >= MAX_RECONNECTION_ATTEMPTS:
Comment thread
Epochex marked this conversation as resolved.
assert isinstance(ctx.session_message.message, JSONRPCRequest)
original_request_id = ctx.session_message.message.id
error_data = ErrorData(
code=CONNECTION_CLOSED,
message="SSE stream disconnected and could not be resumed",
data={"last_event_id": last_event_id},
)
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
await _send_or_ignore_closed(ctx.read_stream_writer, error_msg)
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
return

Expand All @@ -415,12 +441,15 @@ async def _handle_reconnection(
# Track for potential further reconnection
reconnect_last_event_id: str = last_event_id
reconnect_retry_ms = retry_interval_ms
made_progress = False

async for sse in event_source.aiter_sse():
if sse.id: # pragma: no branch
reconnect_last_event_id = sse.id
if sse.retry is not None:
reconnect_retry_ms = sse.retry
if sse.event == "message" and bool(sse.data):
made_progress = True

is_complete = await self._handle_sse_event(
sse,
Expand All @@ -432,10 +461,12 @@ async def _handle_reconnection(
await event_source.response.aclose()
return

# Stream ended again without response - reconnect again (reset attempt counter)
# Stream ended again without response - reconnect again. Only reset
# the retry counter when the resumed stream delivered real data.
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
except Exception as e: # pragma: no cover
next_attempt = 0 if made_progress else attempt + 1
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, next_attempt)
except Exception as e:
logger.debug(f"Reconnection failed: {e}")
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
Expand Down
109 changes: 107 additions & 2 deletions tests/client/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,23 @@

import base64
import json
from types import SimpleNamespace

import anyio
import httpx
import pytest
from inline_snapshot import snapshot
from mcp_types import METHOD_NOT_FOUND, JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse
from mcp_types import (
CONNECTION_CLOSED,
METHOD_NOT_FOUND,
JSONRPCError,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
)

from mcp.client.streamable_http import streamable_http_client
from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport, streamable_http_client
from mcp.shared._context_streams import create_context_streams
from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER, encode_header_value
from mcp.shared.message import ClientMessageMetadata, SessionMessage

Expand Down Expand Up @@ -52,6 +61,102 @@ def test_mcp_name_header_values_are_base64_wrapped_when_unsafe_for_an_http_field
assert encoded == raw


@pytest.mark.anyio
async def test_sse_response_disconnect_before_any_event_id_fails_request() -> None:
transport = StreamableHTTPTransport("http://example.com/mcp")
async with httpx.AsyncClient() as client:
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](1)
request = JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": "noop", "arguments": {}})
ctx = RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(request),
metadata=None,
read_stream_writer=read_stream_writer,
)
response = httpx.Response(200, headers={"content-type": "text/event-stream"}, content=b"")

async with read_stream_writer, read_stream:
await transport._handle_sse_response(response, ctx)
with anyio.fail_after(5):
message = await read_stream.receive()

assert isinstance(message, SessionMessage)
assert isinstance(message.message, JSONRPCError)
assert message.message.id == 1
assert message.message.error.code == CONNECTION_CLOSED


@pytest.mark.anyio
async def test_reconnection_empty_streams_count_toward_max_attempts(monkeypatch: pytest.MonkeyPatch) -> None:
class PrimingOnlyEventSource:
def __init__(self) -> None:
self.response = httpx.Response(200)

async def __aenter__(self) -> "PrimingOnlyEventSource":
nonlocal reconnect_attempts
reconnect_attempts += 1
return self

async def __aexit__(self, *args: object) -> None:
return None

async def aiter_sse(self) -> object:
yield SimpleNamespace(event="message", data="", id=f"event-{reconnect_attempts}", retry=0)

def connect_sse(*args: object, **kwargs: object) -> PrimingOnlyEventSource:
return PrimingOnlyEventSource()

reconnect_attempts = 0
monkeypatch.setattr(
"mcp.client.streamable_http.aconnect_sse",
connect_sse,
)

transport = StreamableHTTPTransport("http://example.com/mcp")
async with httpx.AsyncClient() as client:
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](1)
request = JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": "noop", "arguments": {}})
ctx = RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(request),
metadata=None,
read_stream_writer=read_stream_writer,
)

async with read_stream_writer, read_stream:
with anyio.fail_after(5):
await transport._handle_reconnection(ctx, "event-1", retry_interval_ms=0)
message = await read_stream.receive()

assert reconnect_attempts == 2
assert isinstance(message, SessionMessage)
assert isinstance(message.message, JSONRPCError)
assert message.message.id == 1
assert message.message.error.code == CONNECTION_CLOSED


@pytest.mark.anyio
async def test_sse_response_disconnect_ignores_closed_read_stream() -> None:
transport = StreamableHTTPTransport("http://example.com/mcp")
async with httpx.AsyncClient() as client:
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](1)
request = JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": "noop", "arguments": {}})
ctx = RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(request),
metadata=None,
read_stream_writer=read_stream_writer,
)
response = httpx.Response(200, headers={"content-type": "text/event-stream"}, content=b"")

async with read_stream_writer, read_stream:
await read_stream.aclose()
await transport._handle_sse_response(response, ctx)


@pytest.mark.anyio
async def test_post_request_merges_per_message_metadata_headers() -> None:
"""`ClientMessageMetadata.headers` on a `SessionMessage` are merged into the outgoing POST headers
Expand Down
76 changes: 76 additions & 0 deletions tests/interaction/transports/test_hosting_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,82 @@ async def call() -> None:
assert received == snapshot(["before close", "after close"])


@requirement("hosting:resume:close-stream")
@requirement("transport:streamable-http:resumability")
@requirement("client-transport:http:reconnect-post-priming")
@requirement("client-transport:http:reconnect-retry-value")
async def test_a_call_whose_stream_closes_and_cannot_be_resumed_fails_instead_of_hanging() -> None:
"""If a resumable response stream disconnects and the server session is gone, the client fails
the request instead of hanging forever.

The server closes the call's SSE stream after emitting one related notification. The test then
deletes the active server-side session to force the client's reconnect GET to return 404.
Without a terminal response/error on the read stream, ClientSession.send_request waits forever
(read timeout defaults to None). The transport must surface a request-scoped error when it
gives up reconnecting.
"""
reconnect_attempted = anyio.Event()
allow_exit = anyio.Event()
done = anyio.Event()
raised: list[BaseException] = []
manager_ref = None
deleted_session = False

mcp = MCPServer("resumable")

@mcp.tool()
async def interrupt(ctx: Context) -> str:
await ctx.info("before close") # pyright: ignore[reportDeprecated]
await ctx.close_sse_stream()
await allow_exit.wait()
return "unreachable"

async def record_request(request: httpx.Request) -> None:
nonlocal deleted_session
if request.method != "GET":
return
if request.headers.get("last-event-id") is None:
return
reconnect_attempted.set()
if deleted_session or manager_ref is None:
return
session_ids = list(manager_ref._server_instances.keys())
if session_ids: # pragma: no branch
del manager_ref._server_instances[session_ids[0]]
deleted_session = True

async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0, on_request=record_request) as (
http,
manager,
):
manager_ref = manager
with anyio.fail_after(5): # pragma: no branch
async with ( # pragma: no branch
streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r, w),
ClientSession(r, w) as session,
anyio.create_task_group() as tg,
):
await session.initialize()

async def call() -> None:
try:
await session.call_tool("interrupt", {})
except BaseException as exc:
raised.append(exc)
finally:
done.set()

tg.start_soon(call)
await reconnect_attempted.wait()
await done.wait()
allow_exit.set()
tg.cancel_scope.cancel()

assert len(raised) == 1
assert isinstance(raised[0], Exception)
assert "disconnected" in str(raised[0]).lower()


@requirement("client-transport:http:resume-stream-api")
async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_connection() -> None:
"""A resumption token captured via on_resumption_token_update on one connection lets a fresh
Expand Down
Loading