Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/mcp/client/_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _parse_supported(data: Any) -> list[str] | None:
return None


async def negotiate_auto(session: ClientSession) -> None:
async def negotiate_auto(session: ClientSession, protocol_version: str | None = None) -> None:
"""Drive the ``mode='auto'`` connect-time policy on ``session``.

Probes ``server/discover`` once (twice if the server names a mutual
Expand All @@ -65,14 +65,20 @@ async def negotiate_auto(session: ClientSession) -> None:
continue
if supported is not None and not any(v in HANDSHAKE_PROTOCOL_VERSIONS for v in supported):
raise # server is modern-only and disjoint — real incompatibility
await session.initialize() # every other rpc-error → legacy (the denylist)
if protocol_version is not None:
await session.initialize(protocol_version=protocol_version)
else:
await session.initialize() # every other rpc-error → legacy (the denylist)
return
# any other exception (httpx.TransportError, ConnectionError, anyio errors,
# RuntimeError from adopt) → propagate
try:
result = types.DiscoverResult.model_validate(raw)
except ValidationError:
await session.initialize() # unparseable result → not modern evidence
if protocol_version is not None:
await session.initialize(protocol_version=protocol_version)
else:
await session.initialize() # unparseable result → not modern evidence
return
session.adopt(result)
return
Expand Down
9 changes: 7 additions & 2 deletions src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ async def main():
"""Callback for handling elicitation requests."""

_entered: bool = field(init=False, default=False)
protocol_version_override: str | None = None
"""The protocol version to request during initialization. Defaults to the latest version."""
_session: ClientSession | None = field(init=False, default=None)
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
_connect: _Connector = field(init=False, repr=False, compare=False)
Expand Down Expand Up @@ -250,9 +252,12 @@ async def __aenter__(self) -> Client:
session = await exit_stack.enter_async_context(session)

if self.mode == "legacy":
await session.initialize()
if self.protocol_version_override is not None:
await session.initialize(protocol_version=self.protocol_version_override)
else:
await session.initialize()
elif self.mode == "auto":
await negotiate_auto(session)
await negotiate_auto(session, protocol_version=self.protocol_version_override)
else:
session.adopt(self.prior_discover or _synthesize_discover(self.mode))

Expand Down
4 changes: 2 additions & 2 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,13 @@ def _build_capabilities(self) -> types.ClientCapabilities:
)
return types.ClientCapabilities(sampling=sampling, elicitation=elicitation, experimental=None, roots=roots)

async def initialize(self) -> types.InitializeResult:
async def initialize(self, protocol_version: str = LATEST_HANDSHAKE_VERSION) -> types.InitializeResult:
if self._initialize_result is not None:
return self._initialize_result
result = await self.send_request(
types.InitializeRequest(
params=types.InitializeRequestParams(
protocol_version=LATEST_HANDSHAKE_VERSION,
protocol_version=protocol_version,
capabilities=self._build_capabilities(),
client_info=self._client_info,
),
Expand Down
6 changes: 5 additions & 1 deletion src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class ClientSessionParameters:
logging_callback: LoggingFnT | None = None
message_handler: MessageHandlerFnT | None = None
client_info: types.Implementation | None = None
protocol_version: str | None = None


class ClientSessionGroup:
Expand Down Expand Up @@ -352,7 +353,10 @@ async def _establish_session(
)
)

result = await session.initialize()
if session_params.protocol_version is not None:
result = await session.initialize(protocol_version=session_params.protocol_version)
else:
result = await session.initialize()

# Session successfully initialized.
# Store its stack and register the stack with the main group stack.
Expand Down
7 changes: 7 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ async def test_client_exposes_negotiated_protocol_version(app: MCPServer):
assert client.protocol_version == LATEST_HANDSHAKE_VERSION


async def test_client_custom_protocol_version(app: MCPServer):
"""Test that the client negotiates a custom protocol version when configured."""
async with Client(app, mode="legacy", protocol_version_override="2024-11-05") as client:
assert client.protocol_version == "2024-11-05"
assert client.server_info.name == "test"


async def test_client_with_simple_server(simple_server: Server):
"""Test that from_server works with a basic Server instance."""
async with Client(simple_server) as client:
Expand Down
25 changes: 22 additions & 3 deletions tests/client/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, *script: dict[str, Any] | Exception) -> None:
self._script: list[dict[str, Any] | Exception] = list(script)
self.probed_at: list[str] = []
self.initialized: bool = False
self.initialize_version: str | None = None
self.adopted: types.DiscoverResult | None = None

async def send_discover(self, version: str) -> dict[str, Any]:
Expand All @@ -61,16 +62,17 @@ async def send_discover(self, version: str) -> dict[str, Any]:
raise step
return step

async def initialize(self) -> None:
async def initialize(self, protocol_version: str | None = None) -> None:
self.initialized = True
self.initialize_version = protocol_version

def adopt(self, result: types.DiscoverResult) -> None:
self.adopted = result


async def _negotiate(session: _StubSession) -> None:
async def _negotiate(session: _StubSession, protocol_version: str | None = None) -> None:
"""Drive `negotiate_auto` against the stub; cast at one seam so the tests stay suppression-free."""
await negotiate_auto(cast("ClientSession", session))
await negotiate_auto(cast("ClientSession", session), protocol_version=protocol_version)


def _discover_dict(versions: list[str] | None = None) -> dict[str, Any]:
Expand Down Expand Up @@ -240,3 +242,20 @@ def test_parse_supported_returns_none_for_anything_not_shaped_like_the_spec_erro
"""`_parse_supported` returns the `supported` list when `error.data` validates as
`UnsupportedProtocolVersionErrorData`, and `None` otherwise — never raises."""
assert _parse_supported(data) == expected


async def test_negotiate_auto_mcp_error_with_custom_protocol_version() -> None:
"""Test that negotiate_auto initializes with a custom protocol version when discover returns an MCPError."""
session = _StubSession(MCPError(code=METHOD_NOT_FOUND, message="nope"))
await _negotiate(session, protocol_version="2024-11-05")
assert session.initialized
assert session.initialize_version == "2024-11-05"


async def test_negotiate_auto_validation_error_with_custom_protocol_version() -> None:
"""Test that negotiate_auto initializes with a custom protocol version when discover returns unparseable result."""
session = _StubSession({"not": "a discover result"})
await _negotiate(session, protocol_version="2024-11-05")
assert session.initialized
assert session.initialize_version == "2024-11-05"

84 changes: 84 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,90 @@ async def message_handler( # pragma: no cover
assert isinstance(initialized_notification, InitializedNotification)


@pytest.mark.anyio
async def test_client_session_initialize_custom_protocol_version():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

initialized_notification = None
result = None

async def mock_server():
nonlocal initialized_notification

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request, JSONRPCRequest)
request = client_request_adapter.validate_python(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request, InitializeRequest)
assert request.params.protocol_version == "2024-11-05"

result = InitializeResult(
protocol_version="2024-11-05",
capabilities=ServerCapabilities(
logging=None,
resources=None,
tools=None,
experimental=None,
prompts=None,
),
server_info=Implementation(name="mock-server", version="0.1.0"),
instructions="The server instructions.",
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.id,
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
session_notification = await client_to_server_receive.receive()
jsonrpc_notification = session_notification.message
assert isinstance(jsonrpc_notification, JSONRPCNotification)
initialized_notification = client_notification_adapter.validate_python(
jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
)

# Create a message handler to catch exceptions
async def message_handler( # pragma: no cover
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
result = await session.initialize(protocol_version="2024-11-05")

# Assert the result
assert isinstance(result, InitializeResult)
assert result.protocol_version == "2024-11-05"
assert isinstance(result.capabilities, ServerCapabilities)
assert result.server_info == Implementation(name="mock-server", version="0.1.0")
assert result.instructions == "The server instructions."

# Check that the client sent the initialized notification
assert initialized_notification
assert isinstance(initialized_notification, InitializedNotification)


@pytest.mark.anyio
async def test_client_session_custom_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
Expand Down
34 changes: 34 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,37 @@ async def test_client_session_group_establish_session_parameterized(
# 3. Assert returned values
assert returned_server_info is mock_initialize_result.server_info
assert returned_session is mock_entered_session


@pytest.mark.anyio
async def test_client_session_group_establish_session_custom_protocol_version():
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
with mock.patch("mcp.client.session_group.mcp.stdio_client") as mock_stdio_client:
mock_client_cm_instance = mock.AsyncMock(name="stdioClientCM")
mock_read_stream = mock.AsyncMock(name="stdioRead")
mock_write_stream = mock.AsyncMock(name="stdioWrite")

mock_client_cm_instance.__aenter__.return_value = (mock_read_stream, mock_write_stream)
mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None)
mock_stdio_client.return_value = mock_client_cm_instance

mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM")
mock_ClientSession_class.return_value = mock_raw_session_cm

mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance")
mock_raw_session_cm.__aenter__.return_value = mock_entered_session
mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None)

mock_initialize_result = mock.AsyncMock(name="InitializeResult")
mock_initialize_result.server_info = types.Implementation(name="foo", version="1")
mock_entered_session.initialize.return_value = mock_initialize_result

group = ClientSessionGroup()
server_params = StdioServerParameters(command="test_stdio_cmd")
session_params = ClientSessionParameters(protocol_version="2024-11-05")

async with contextlib.AsyncExitStack() as stack:
group._exit_stack = stack
await group._establish_session(server_params, session_params)

mock_entered_session.initialize.assert_awaited_once_with(protocol_version="2024-11-05")
Loading