diff --git a/src/mcp/client/_probe.py b/src/mcp/client/_probe.py index 39a5c5296..15935870a 100644 --- a/src/mcp/client/_probe.py +++ b/src/mcp/client/_probe.py @@ -38,7 +38,7 @@ def _parse_supported(data: Any) -> list[str] | None: return None -async def negotiate_auto(session: ClientSession) -> None: +async def negotiate_auto(session: ClientSession, protocol_version: str | None = None) -> None: """Drive the ``mode='auto'`` connect-time policy on ``session``. Probes ``server/discover`` once (twice if the server names a mutual @@ -65,14 +65,20 @@ async def negotiate_auto(session: ClientSession) -> None: continue if supported is not None and not any(v in HANDSHAKE_PROTOCOL_VERSIONS for v in supported): raise # server is modern-only and disjoint — real incompatibility - await session.initialize() # every other rpc-error → legacy (the denylist) + if protocol_version is not None: + await session.initialize(protocol_version=protocol_version) + else: + await session.initialize() # every other rpc-error → legacy (the denylist) return # any other exception (httpx.TransportError, ConnectionError, anyio errors, # RuntimeError from adopt) → propagate try: result = types.DiscoverResult.model_validate(raw) except ValidationError: - await session.initialize() # unparseable result → not modern evidence + if protocol_version is not None: + await session.initialize(protocol_version=protocol_version) + else: + await session.initialize() # unparseable result → not modern evidence return session.adopt(result) return diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 308f28d1c..862075ac3 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -200,6 +200,8 @@ async def main(): """Callback for handling elicitation requests.""" _entered: bool = field(init=False, default=False) + protocol_version_override: str | None = None + """The protocol version to request during initialization. Defaults to the latest version.""" _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) _connect: _Connector = field(init=False, repr=False, compare=False) @@ -250,9 +252,12 @@ async def __aenter__(self) -> Client: session = await exit_stack.enter_async_context(session) if self.mode == "legacy": - await session.initialize() + if self.protocol_version_override is not None: + await session.initialize(protocol_version=self.protocol_version_override) + else: + await session.initialize() elif self.mode == "auto": - await negotiate_auto(session) + await negotiate_auto(session, protocol_version=self.protocol_version_override) else: session.adopt(self.prior_discover or _synthesize_discover(self.mode)) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0c6e0270c..8705012e8 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -350,13 +350,13 @@ def _build_capabilities(self) -> types.ClientCapabilities: ) return types.ClientCapabilities(sampling=sampling, elicitation=elicitation, experimental=None, roots=roots) - async def initialize(self) -> types.InitializeResult: + async def initialize(self, protocol_version: str = LATEST_HANDSHAKE_VERSION) -> types.InitializeResult: if self._initialize_result is not None: return self._initialize_result result = await self.send_request( types.InitializeRequest( params=types.InitializeRequestParams( - protocol_version=LATEST_HANDSHAKE_VERSION, + protocol_version=protocol_version, capabilities=self._build_capabilities(), client_info=self._client_info, ), diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 40f023259..d922ed46d 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -80,6 +80,7 @@ class ClientSessionParameters: logging_callback: LoggingFnT | None = None message_handler: MessageHandlerFnT | None = None client_info: types.Implementation | None = None + protocol_version: str | None = None class ClientSessionGroup: @@ -352,7 +353,10 @@ async def _establish_session( ) ) - result = await session.initialize() + if session_params.protocol_version is not None: + result = await session.initialize(protocol_version=session_params.protocol_version) + else: + result = await session.initialize() # Session successfully initialized. # Store its stack and register the stack with the main group stack. diff --git a/tests/client/test_client.py b/tests/client/test_client.py index f869d1f1b..2e8ec8bc4 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -126,6 +126,13 @@ async def test_client_exposes_negotiated_protocol_version(app: MCPServer): assert client.protocol_version == LATEST_HANDSHAKE_VERSION +async def test_client_custom_protocol_version(app: MCPServer): + """Test that the client negotiates a custom protocol version when configured.""" + async with Client(app, mode="legacy", protocol_version_override="2024-11-05") as client: + assert client.protocol_version == "2024-11-05" + assert client.server_info.name == "test" + + async def test_client_with_simple_server(simple_server: Server): """Test that from_server works with a basic Server instance.""" async with Client(simple_server) as client: diff --git a/tests/client/test_probe.py b/tests/client/test_probe.py index 34a347fa7..06fcae69d 100644 --- a/tests/client/test_probe.py +++ b/tests/client/test_probe.py @@ -52,6 +52,7 @@ def __init__(self, *script: dict[str, Any] | Exception) -> None: self._script: list[dict[str, Any] | Exception] = list(script) self.probed_at: list[str] = [] self.initialized: bool = False + self.initialize_version: str | None = None self.adopted: types.DiscoverResult | None = None async def send_discover(self, version: str) -> dict[str, Any]: @@ -61,16 +62,17 @@ async def send_discover(self, version: str) -> dict[str, Any]: raise step return step - async def initialize(self) -> None: + async def initialize(self, protocol_version: str | None = None) -> None: self.initialized = True + self.initialize_version = protocol_version def adopt(self, result: types.DiscoverResult) -> None: self.adopted = result -async def _negotiate(session: _StubSession) -> None: +async def _negotiate(session: _StubSession, protocol_version: str | None = None) -> None: """Drive `negotiate_auto` against the stub; cast at one seam so the tests stay suppression-free.""" - await negotiate_auto(cast("ClientSession", session)) + await negotiate_auto(cast("ClientSession", session), protocol_version=protocol_version) def _discover_dict(versions: list[str] | None = None) -> dict[str, Any]: @@ -240,3 +242,20 @@ def test_parse_supported_returns_none_for_anything_not_shaped_like_the_spec_erro """`_parse_supported` returns the `supported` list when `error.data` validates as `UnsupportedProtocolVersionErrorData`, and `None` otherwise — never raises.""" assert _parse_supported(data) == expected + + +async def test_negotiate_auto_mcp_error_with_custom_protocol_version() -> None: + """Test that negotiate_auto initializes with a custom protocol version when discover returns an MCPError.""" + session = _StubSession(MCPError(code=METHOD_NOT_FOUND, message="nope")) + await _negotiate(session, protocol_version="2024-11-05") + assert session.initialized + assert session.initialize_version == "2024-11-05" + + +async def test_negotiate_auto_validation_error_with_custom_protocol_version() -> None: + """Test that negotiate_auto initializes with a custom protocol version when discover returns unparseable result.""" + session = _StubSession({"not": "a discover result"}) + await _negotiate(session, protocol_version="2024-11-05") + assert session.initialized + assert session.initialize_version == "2024-11-05" + diff --git a/tests/client/test_session.py b/tests/client/test_session.py index b66ca1bba..f68347c2d 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -153,6 +153,90 @@ async def message_handler( # pragma: no cover assert isinstance(initialized_notification, InitializedNotification) +@pytest.mark.anyio +async def test_client_session_initialize_custom_protocol_version(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + initialized_notification = None + result = None + + async def mock_server(): + nonlocal initialized_notification + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request, JSONRPCRequest) + request = client_request_adapter.validate_python( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request, InitializeRequest) + assert request.params.protocol_version == "2024-11-05" + + result = InitializeResult( + protocol_version="2024-11-05", + capabilities=ServerCapabilities( + logging=None, + resources=None, + tools=None, + experimental=None, + prompts=None, + ), + server_info=Implementation(name="mock-server", version="0.1.0"), + instructions="The server instructions.", + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + session_notification = await client_to_server_receive.receive() + jsonrpc_notification = session_notification.message + assert isinstance(jsonrpc_notification, JSONRPCNotification) + initialized_notification = client_notification_adapter.validate_python( + jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + + # Create a message handler to catch exceptions + async def message_handler( # pragma: no cover + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + result = await session.initialize(protocol_version="2024-11-05") + + # Assert the result + assert isinstance(result, InitializeResult) + assert result.protocol_version == "2024-11-05" + assert isinstance(result.capabilities, ServerCapabilities) + assert result.server_info == Implementation(name="mock-server", version="0.1.0") + assert result.instructions == "The server instructions." + + # Check that the client sent the initialized notification + assert initialized_notification + assert isinstance(initialized_notification, InitializedNotification) + + @pytest.mark.anyio async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index dae076616..6b1dcdd9b 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -402,3 +402,37 @@ async def test_client_session_group_establish_session_parameterized( # 3. Assert returned values assert returned_server_info is mock_initialize_result.server_info assert returned_session is mock_entered_session + + +@pytest.mark.anyio +async def test_client_session_group_establish_session_custom_protocol_version(): + with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: + with mock.patch("mcp.client.session_group.mcp.stdio_client") as mock_stdio_client: + mock_client_cm_instance = mock.AsyncMock(name="stdioClientCM") + mock_read_stream = mock.AsyncMock(name="stdioRead") + mock_write_stream = mock.AsyncMock(name="stdioWrite") + + mock_client_cm_instance.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None) + mock_stdio_client.return_value = mock_client_cm_instance + + mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM") + mock_ClientSession_class.return_value = mock_raw_session_cm + + mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance") + mock_raw_session_cm.__aenter__.return_value = mock_entered_session + mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + mock_initialize_result = mock.AsyncMock(name="InitializeResult") + mock_initialize_result.server_info = types.Implementation(name="foo", version="1") + mock_entered_session.initialize.return_value = mock_initialize_result + + group = ClientSessionGroup() + server_params = StdioServerParameters(command="test_stdio_cmd") + session_params = ClientSessionParameters(protocol_version="2024-11-05") + + async with contextlib.AsyncExitStack() as stack: + group._exit_stack = stack + await group._establish_session(server_params, session_params) + + mock_entered_session.initialize.assert_awaited_once_with(protocol_version="2024-11-05")