diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py index 2d2acf930..4a57d5aee 100644 --- a/.github/actions/conformance/client.py +++ b/.github/actions/conformance/client.py @@ -20,7 +20,9 @@ json-schema-ref-no-deref - Connect, list tools (no $ref deref) request-metadata - Connect with all callbacks; client stamps _meta http-standard-headers - Connect, call a tool (Mcp-* headers checked) + http-invalid-tool-headers - List tools, call every surfaced tool (x-mcp-header filter) elicitation-sep1034-client-defaults - Elicitation with default accept callback + sep-2322-client-request-state - Drive the manual MRTR retry surface auth/client-credentials-jwt - Client credentials with private_key_jwt auth/client-credentials-basic - Client credentials with client_secret_basic auth/* - Authorization code flow (default for auth scenarios) @@ -296,6 +298,43 @@ async def run_http_standard_headers(server_url: str) -> None: logger.debug(f"add_numbers result: {result}") +def _stub_required_args(input_schema: dict[str, Any]) -> dict[str, Any]: + """Minimal arguments satisfying a tool inputSchema's required list.""" + by_type: dict[str, Any] = { + "string": "x", + "integer": 0, + "number": 0, + "boolean": False, + "object": {}, + "array": [], + "null": None, + } + properties = input_schema.get("properties", {}) + return {name: by_type.get(properties.get(name, {}).get("type"), "x") for name in input_schema.get("required", [])} + + +@register("http-invalid-tool-headers") +async def run_http_invalid_tool_headers(server_url: str) -> None: + """List tools, then call every tool the SDK surfaces (SEP-2243). + + The harness mock advertises one valid tool plus several with malformed + x-mcp-header annotations (empty, non-primitive type, duplicate, invalid + chars). The scenario passes if valid_tool is called and the malformed + ones are not -- so a conforming client filters them out of the list_tools + result and the loop below never sees them. The scenario sets + allowClientError, so a per-call failure is logged and skipped rather + than aborting the whole run. + """ + async with Client(server_url, mode=client_mode()) as client: + listed = await client.list_tools() + logger.debug(f"Surfaced tools: {[t.name for t in listed.tools]}") + for tool in listed.tools: + try: + await client.call_tool(tool.name, _stub_required_args(tool.input_schema)) + except Exception: + logger.exception(f"call_tool({tool.name!r}) failed") + + @register("elicitation-sep1034-client-defaults") async def run_elicitation_defaults(server_url: str) -> None: """Connect with elicitation callback that applies schema defaults.""" @@ -305,6 +344,53 @@ async def run_elicitation_defaults(server_url: str) -> None: logger.debug(f"test_client_elicitation_defaults result: {result}") +@register("sep-2322-client-request-state") +async def run_mrtr_client(server_url: str) -> None: + """Drive the manual MRTR retry surface against the SEP-2322 client mock. + + The mock speaks the modern lifecycle (server/discover, no initialize) and + inspects the wire params of each tools/call round, so this exercises the + explicit allow_input_required=True path rather than an auto-loop: round 1 + receives an InputRequiredResult, the fixture fulfils the elicitation + locally, then round 2 retries with input_responses + the echoed + request_state. Passing request_state straight off the typed result -- a + str when the server sent one, None when it didn't -- lets the + serializer's exclude_none drop the key in the no-state case without a + branch here. The unrelated call between rounds proves MRTR params don't + leak across tools, and the no-result-type call must parse as a complete + CallToolResult with no retry. + """ + async with Client(server_url, mode=client_mode()) as client: + await client.list_tools() + confirm = {"confirm": types.ElicitResult(action="accept", content={"confirmed": True})} + + r1 = await client.call_tool("test_mrtr_echo_state", {}, allow_input_required=True) + assert isinstance(r1, types.InputRequiredResult) + + await client.call_tool("test_mrtr_unrelated", {}) + + await client.call_tool( + "test_mrtr_echo_state", + {}, + input_responses=confirm, + request_state=r1.request_state, + allow_input_required=True, + ) + + r2 = await client.call_tool("test_mrtr_no_state", {}, allow_input_required=True) + assert isinstance(r2, types.InputRequiredResult) + await client.call_tool( + "test_mrtr_no_state", + {}, + input_responses=confirm, + request_state=r2.request_state, + allow_input_required=True, + ) + + result = await client.call_tool("test_mrtr_no_result_type", {}) + assert isinstance(result, types.CallToolResult) + + @register("auth/client-credentials-jwt") async def run_client_credentials_jwt(server_url: str) -> None: """Client credentials flow with private_key_jwt authentication.""" @@ -441,8 +527,7 @@ def main() -> None: asyncio.run(run_auth_code_client(server_url)) else: # Unhandled scenarios: - # - sep-2322-client-request-state (SEP-2322 / S6: MRTR client loop) - # - http-custom-headers, http-invalid-tool-headers (SEP-2243 / S8: Mcp-Param-* headers) + # - http-custom-headers (SEP-2243 / S8: Mcp-Param-* emission) print(f"Unknown scenario: {scenario}", file=sys.stderr) sys.exit(1) else: diff --git a/.github/actions/conformance/expected-failures.2026-07-28.yml b/.github/actions/conformance/expected-failures.2026-07-28.yml index 529eb8bab..a4b4f4480 100644 --- a/.github/actions/conformance/expected-failures.2026-07-28.yml +++ b/.github/actions/conformance/expected-failures.2026-07-28.yml @@ -21,48 +21,19 @@ # milestone. client: - # --- Same gaps as the 2025 baseline (fail identically when forced to 2026-07-28) --- - # SEP-2322 (multi-round-trip requests): client does not echo requestState / - # handle IncompleteResult yet. - - sep-2322-client-request-state - # SEP-2243 (HTTP standardization): no fixture handler / client Mcp-Param-* support yet. + # SEP-2243 (HTTP standardization): no client Mcp-Param-* support yet — needs the + # tool-schema-cache vs per-call tool_definition design (S8). - http-custom-headers - - http-invalid-tool-headers # auth/enterprise-managed-authorization (SEP-990) is in the 2025 baseline but # NOT here: the harness skips it as inapplicable at --spec-version 2026-07-28 # (it is an extension scenario not carried into the 2026 wire), so it is # neither run nor evaluated on this leg. server: - # --- Carried-forward 2025-era scenarios still failing on the 2026 wire --- # The stateless 2026 path now reaches handlers for plain request/response # scenarios; tools-call-with-progress still fails because the stateless # server has no channel for server→client progress notifications. - tools-call-with-progress - # SEP-2106 (JSON Schema 2020-12 in tool inputSchema): the fixture tool's - # schema has none of the 2020-12 keywords the scenario checks. The scenario - # is in `--suite all` but not `--suite active`, so this is the only leg that - # runs it; it fails identically at 2025-11-25 (not a 2026-path regression). - - json-schema-2020-12 - - # --- Draft scenarios (same failures and reasons as the `--suite draft` leg) --- - # SEP-2322 (multi-round-trip requests / IncompleteResult): not implemented. - - input-required-result-basic-elicitation - - input-required-result-basic-sampling - - input-required-result-basic-list-roots - - input-required-result-request-state - - input-required-result-multiple-input-requests - - input-required-result-multi-round + # SEP-2322 (multi-round-trip requests / IncompleteResult): the prompt pipeline + # cannot return InputRequiredResult from MCPServer yet (tools/call can). - input-required-result-non-tool-request - - input-required-result-result-type - - input-required-result-tampered-state - - input-required-result-capability-check - # SEP-2243 (HTTP header standardization): Mcp-Method / Mcp-Name cross-check - # against the request body is not implemented. - - http-header-validation - # WARNING-only entries: these scenarios emit no FAILURE checks but the - # expected-failures evaluator counts WARNINGs as failures (the summary line - # only shows passed/failed, not warnings, so a local re-probe can mis-read - # these as stale). - - input-required-result-missing-input-response - - input-required-result-validate-input diff --git a/.github/actions/conformance/expected-failures.yml b/.github/actions/conformance/expected-failures.yml index 2a411b4cd..cb59dba02 100644 --- a/.github/actions/conformance/expected-failures.yml +++ b/.github/actions/conformance/expected-failures.yml @@ -12,12 +12,9 @@ client: # --- Draft-spec scenarios (in `--suite draft`, also part of `--suite all`) --- - # SEP-2322 (multi-round-trip requests): client does not echo requestState / - # handle IncompleteResult yet. - - sep-2322-client-request-state - # SEP-2243 (HTTP standardization): no fixture handler / client Mcp-Param-* support yet. + # SEP-2243 (HTTP standardization): no client Mcp-Param-* support yet — needs the + # tool-schema-cache vs per-call tool_definition design (S8). - http-custom-headers - - http-invalid-tool-headers # --- Pre-existing scenarios that fail on checks added after conformance 0.1.15 --- # SEP-990 (enterprise-managed authorization extension): no fixture handler / @@ -26,23 +23,6 @@ client: server: # --- Draft-spec scenarios (in `--suite draft`; the `active` suite is green) --- - # SEP-2322 (multi-round-trip requests / IncompleteResult): not implemented. - - input-required-result-basic-elicitation - - input-required-result-basic-sampling - - input-required-result-basic-list-roots - - input-required-result-request-state - - input-required-result-multiple-input-requests - - input-required-result-multi-round + # SEP-2322 (multi-round-trip requests / IncompleteResult): the prompt pipeline + # cannot return InputRequiredResult from MCPServer yet (tools/call can). - input-required-result-non-tool-request - - input-required-result-result-type - - input-required-result-tampered-state - - input-required-result-capability-check - # SEP-2243 (HTTP header standardization): Mcp-Method / Mcp-Name cross-check - # against the request body is not implemented. - - http-header-validation - # WARNING-only entries: these scenarios emit no FAILURE checks but the - # expected-failures evaluator counts WARNINGs as failures (the summary line - # only shows passed/failed, not warnings, so a local re-probe can mis-read - # these as stale). - - input-required-result-missing-input-response - - input-required-result-validate-input diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index e985a52f6..9f5ce489f 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -15,15 +15,10 @@ permissions: env: # Pinned conformance harness package spec (passed verbatim to `npx --yes`). - # Use a published version, e.g. @modelcontextprotocol/conformance@0.2.0-alpha.5. + # Use a published version, e.g. @modelcontextprotocol/conformance@0.2.0-alpha.7. # Bump deliberately and reconcile both # .github/actions/conformance/expected-failures*.yml files in the same change. - # - # TODO: replace with @modelcontextprotocol/conformance@0.2.0-alpha.5 once - # https://github.com/modelcontextprotocol/conformance/pull/357 publishes, and - # drop CONFORMANCE_PKG_SHA256 plus the fetch-and-verify step below. - CONFORMANCE_PKG: "https://pkg.pr.new/@modelcontextprotocol/conformance@65fcd39" - CONFORMANCE_PKG_SHA256: "9a381d7083f8be2fe7ae44efeca54530f18c61425805ddaf9cd88915efcc1574" + CONFORMANCE_PKG: "@modelcontextprotocol/conformance@0.2.0-alpha.7" jobs: server-conformance: @@ -39,19 +34,6 @@ jobs: - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 with: node-version: 24 - - name: Fetch and verify conformance harness - # Only when CONFORMANCE_PKG is a URL: download, check the recorded - # sha256, and re-point CONFORMANCE_PKG at the verified local tarball. - # When CONFORMANCE_PKG is a registry spec, this step is a no-op (npm's - # own integrity check applies). - run: | - case "$CONFORMANCE_PKG" in - https://*) - curl -fsSL "$CONFORMANCE_PKG" -o /tmp/conformance.tgz - echo "$CONFORMANCE_PKG_SHA256 /tmp/conformance.tgz" | sha256sum -c - - echo "CONFORMANCE_PKG=file:/tmp/conformance.tgz" >> "$GITHUB_ENV" - ;; - esac - run: uv sync --frozen --all-extras --package mcp-everything-server - name: Run server conformance (active suite) run: >- @@ -83,26 +65,22 @@ jobs: - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 with: node-version: 24 - - name: Fetch and verify conformance harness - run: | - case "$CONFORMANCE_PKG" in - https://*) - curl -fsSL "$CONFORMANCE_PKG" -o /tmp/conformance.tgz - echo "$CONFORMANCE_PKG_SHA256 /tmp/conformance.tgz" | sha256sum -c - - echo "CONFORMANCE_PKG=file:/tmp/conformance.tgz" >> "$GITHUB_ENV" - ;; - esac - run: uv sync --frozen --all-extras --package mcp - name: Run client conformance (all suite) + # The harness runs all scenarios via unbounded Promise.all; with 40 + # scenarios on a 2-core runner the slowest one (sse-retry, which has a + # real-time SSE reconnect wait) needs more than the 30s default budget. run: >- npx --yes "$CONFORMANCE_PKG" client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all + --timeout 60000 --expected-failures ./.github/actions/conformance/expected-failures.yml - name: Run client conformance (2026-07-28 wire, all suite) run: >- npx --yes "$CONFORMANCE_PKG" client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all + --timeout 60000 --spec-version 2026-07-28 --expected-failures ./.github/actions/conformance/expected-failures.2026-07-28.yml diff --git a/docs/migration.md b/docs/migration.md index e977ce4a2..7598b5202 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -10,7 +10,8 @@ Version 2 of the MCP Python SDK introduces several breaking changes to improve t ### `MCPServer.call_tool()` returns `CallToolResult` -`MCPServer.call_tool()` now always returns a `CallToolResult`. It previously +`MCPServer.call_tool()` now returns a `CallToolResult` (or an +`InputRequiredResult` when a multi-round tool requests further input). It previously advertised `Sequence[ContentBlock] | dict[str, Any]` and leaked the internal conversion shapes (a bare content sequence or a `(content, structured_content)` tuple), forcing callers to re-assemble a `CallToolResult` themselves. diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index c43b6735c..f622aac7a 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -6,8 +6,12 @@ import asyncio import base64 +import binascii +import hashlib +import hmac import json import logging +from typing import Any import click from mcp.server import ServerRequestContext @@ -20,10 +24,20 @@ Completion, CompletionArgument, CompletionContext, + CreateMessageRequest, + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequest, + ElicitRequestFormParams, + ElicitResult, EmbeddedResource, EmptyResult, ImageContent, + InputRequest, + InputRequiredResult, JSONRPCMessage, + ListRootsRequest, + ListRootsResult, PromptReference, ResourceTemplateReference, SamplingMessage, @@ -33,7 +47,7 @@ TextResourceContents, UnsubscribeRequestParams, ) -from mcp_types.jsonrpc import MISSING_REQUIRED_CLIENT_CAPABILITY +from mcp_types.jsonrpc import INVALID_PARAMS, MISSING_REQUIRED_CLIENT_CAPABILITY from pydantic import BaseModel, Field logger = logging.getLogger(__name__) @@ -333,6 +347,231 @@ async def test_missing_capability(ctx: Context) -> str: return "Client declared sampling capability; proceeding." +# SEP-2322 InputRequiredResult fixtures (multi-round-trip / ephemeral workflow) + +NAME_SCHEMA = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} + + +def _name_elicitation(message: str = "What is your name?") -> ElicitRequest: + return ElicitRequest(params=ElicitRequestFormParams(message=message, requested_schema=NAME_SCHEMA)) + + +@mcp.tool() +async def test_input_required_result_elicitation(ctx: Context) -> str | InputRequiredResult: + """Tests InputRequiredResult with a single elicitation request""" + responses = ctx.input_responses + if responses and "user_name" in responses: + answer = responses["user_name"] + name = answer.content.get("name", "stranger") if isinstance(answer, ElicitResult) and answer.content else "?" + return f"Hello, {name}!" + return InputRequiredResult(input_requests={"user_name": _name_elicitation()}) + + +@mcp.tool() +async def test_input_required_result_sampling(ctx: Context) -> str | InputRequiredResult: + """Tests InputRequiredResult with a single sampling request""" + responses = ctx.input_responses + if responses and "capital_question" in responses: + answer = responses["capital_question"] + text = answer.content.text if isinstance(answer, CreateMessageResult) and answer.content.type == "text" else "?" + return f"Model said: {text}" + return InputRequiredResult( + input_requests={ + "capital_question": CreateMessageRequest( + params=CreateMessageRequestParams( + messages=[ + SamplingMessage( + role="user", content=TextContent(type="text", text="What is the capital of France?") + ) + ], + max_tokens=100, + ) + ) + } + ) + + +@mcp.tool() +async def test_input_required_result_list_roots(ctx: Context) -> str | InputRequiredResult: + """Tests InputRequiredResult with a single roots/list request""" + responses = ctx.input_responses + if responses and "client_roots" in responses: + answer = responses["client_roots"] + count = len(answer.roots) if isinstance(answer, ListRootsResult) else 0 + return f"Client exposed {count} root(s)." + return InputRequiredResult(input_requests={"client_roots": ListRootsRequest()}) + + +@mcp.tool() +async def test_input_required_result_request_state(ctx: Context) -> str | InputRequiredResult: + """Tests requestState round-tripping in the InputRequiredResult flow""" + responses = ctx.input_responses + if responses and "confirm" in responses and ctx.request_state == "request-state-nonce": + return "state-ok: confirmation received" + confirm = ElicitRequest( + params=ElicitRequestFormParams( + message="Please confirm", + requested_schema={"type": "object", "properties": {"ok": {"type": "boolean"}}, "required": ["ok"]}, + ) + ) + return InputRequiredResult(input_requests={"confirm": confirm}, request_state="request-state-nonce") + + +@mcp.tool() +async def test_input_required_result_multiple_inputs(ctx: Context) -> str | InputRequiredResult: + """Tests InputRequiredResult carrying elicitation, sampling and roots requests together""" + responses = ctx.input_responses + if responses and {"user_name", "greeting", "client_roots"} <= responses.keys(): + return "All inputs received." + return InputRequiredResult( + input_requests={ + "user_name": _name_elicitation(), + "greeting": CreateMessageRequest( + params=CreateMessageRequestParams( + messages=[ + SamplingMessage(role="user", content=TextContent(type="text", text="Generate a greeting")) + ], + max_tokens=50, + ) + ), + "client_roots": ListRootsRequest(), + }, + request_state="multiple-inputs", + ) + + +@mcp.tool() +async def test_input_required_result_multi_round(ctx: Context) -> str | InputRequiredResult: + """Tests a three-round InputRequiredResult flow with evolving requestState""" + state = json.loads(ctx.request_state) if ctx.request_state else {"round": 0} + responses = ctx.input_responses or {} + + if state["round"] == 0: + return InputRequiredResult( + input_requests={"step1": _name_elicitation("Step 1: What is your name?")}, + request_state=json.dumps({"round": 1}), + ) + + if state["round"] == 1 and "step1" in responses: + step1 = responses["step1"] + name = step1.content.get("name") if isinstance(step1, ElicitResult) and step1.content else None + color_schema = {"type": "object", "properties": {"color": {"type": "string"}}, "required": ["color"]} + return InputRequiredResult( + input_requests={ + "step2": ElicitRequest( + params=ElicitRequestFormParams( + message="Step 2: What is your favorite color?", requested_schema=color_schema + ) + ) + }, + request_state=json.dumps({"round": 2, "name": name}), + ) + + if state["round"] == 2 and "step2" in responses: + step2 = responses["step2"] + color = step2.content.get("color") if isinstance(step2, ElicitResult) and step2.content else None + return f"{state.get('name')} likes {color}." + + # Missing or out-of-order response: re-request from the start. + return InputRequiredResult( + input_requests={"step1": _name_elicitation("Step 1: What is your name?")}, + request_state=json.dumps({"round": 1}), + ) + + +# Fixed key for the conformance fixture; a real server would derive or rotate this. +_STATE_HMAC_KEY = b"everything-server-fixture-key" + + +def _seal_state(payload: str) -> str: + encoded = base64.urlsafe_b64encode(payload.encode()).decode() + sig = hmac.new(_STATE_HMAC_KEY, encoded.encode(), hashlib.sha256).hexdigest() + return f"{encoded}.{sig}" + + +def _unseal_state(state: str) -> str: + encoded, _, sig = state.partition(".") + expected = hmac.new(_STATE_HMAC_KEY, encoded.encode(), hashlib.sha256).hexdigest() + if not sig or not hmac.compare_digest(sig, expected): + raise MCPError(code=INVALID_PARAMS, message="requestState failed integrity verification") + try: + return base64.urlsafe_b64decode(encoded).decode() + except (binascii.Error, UnicodeDecodeError) as e: + raise MCPError(code=INVALID_PARAMS, message="requestState failed integrity verification") from e + + +@mcp.tool() +async def test_input_required_result_tampered_state(ctx: Context) -> str | InputRequiredResult: + """Tests that the server rejects a requestState that fails HMAC verification""" + if ctx.request_state is None: + confirm = ElicitRequest( + params=ElicitRequestFormParams( + message="Please confirm", + requested_schema={"type": "object", "properties": {"ok": {"type": "boolean"}}, "required": ["ok"]}, + ) + ) + return InputRequiredResult(input_requests={"confirm": confirm}, request_state=_seal_state("round-1")) + payload = _unseal_state(ctx.request_state) + return f"state-ok: {payload}" + + +@mcp.tool() +async def test_input_required_result_capabilities(ctx: Context) -> InputRequiredResult: + """Tests that inputRequests only include methods the client declared support for""" + caps = ctx.client_capabilities + requests: dict[str, InputRequest] = {} + if caps is None or caps.sampling is not None: + requests["sample"] = CreateMessageRequest( + params=CreateMessageRequestParams( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello"))], + max_tokens=50, + ) + ) + if caps is None or caps.elicitation is not None: + requests["ask"] = _name_elicitation() + return InputRequiredResult(input_requests=requests, request_state="capability-gated") + + +# SEP-1613 / SEP-2106 JSON Schema 2020-12 fixture: a tool whose inputSchema carries +# the full set of 2020-12 keywords the conformance scenario asserts on. + +JSON_SCHEMA_2020_12_INPUT_SCHEMA: dict[str, Any] = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": { + "address": { + "$anchor": "addressDef", + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + } + }, + "properties": { + "name": {"type": "string"}, + "address": {"$ref": "#/$defs/address"}, + "contactMethod": {"type": "string", "enum": ["phone", "email"]}, + "phone": {"type": "string"}, + "email": {"type": "string"}, + }, + "allOf": [{"anyOf": [{"required": ["phone"]}, {"required": ["email"]}]}], + "if": {"properties": {"contactMethod": {"const": "phone"}}, "required": ["contactMethod"]}, + "then": {"required": ["phone"]}, + "else": {"required": ["email"]}, + "additionalProperties": False, +} + + +@mcp.tool(name="json_schema_2020_12_tool") +def json_schema_2020_12_tool() -> str: + """Tests JSON Schema 2020-12 keyword preservation in tools/list (inputSchema installed below).""" + return "json_schema_2020_12_tool" + + +# TODO(felix): replace with a public input_schema= override once MCPServer.tool() grows one. +mcp._tool_manager._tools["json_schema_2020_12_tool"].parameters = ( # pyright: ignore[reportPrivateUsage] + JSON_SCHEMA_2020_12_INPUT_SCHEMA +) + + @mcp.tool() async def test_reconnection(ctx: Context) -> str: """Tests SSE polling by closing stream mid-call (SEP-1699)""" diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 591223250..0c6e0270c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -38,7 +38,9 @@ MCP_METHOD_HEADER, MCP_NAME_HEADER, MCP_PROTOCOL_VERSION_HEADER, + NAME_BEARING_METHODS, encode_header_value, + find_invalid_x_mcp_header, ) from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.message import ClientMessageMetadata, SessionMessage @@ -78,8 +80,8 @@ def stamp(data: dict[str, Any], opts: CallOptions) -> None: headers = opts.setdefault("headers", {}) headers[MCP_PROTOCOL_VERSION_HEADER] = protocol_version headers[MCP_METHOD_HEADER] = data["method"] - # TODO: also emit Mcp-Name for prompts/get (params.name) and resources/read (params.uri) - if data["method"] == "tools/call" and isinstance(name := params.get("name"), str): + name_key = NAME_BEARING_METHODS.get(data["method"]) + if name_key is not None and isinstance(name := params.get(name_key), str): headers[MCP_NAME_HEADER] = encode_header_value(name) return stamp @@ -429,7 +431,7 @@ async def send_discover(self, version: str) -> dict[str, Any]: opts: CallOptions = { "timeout": DISCOVER_TIMEOUT_SECONDS, "cancel_on_abandon": False, - "headers": {MCP_PROTOCOL_VERSION_HEADER: version}, + "headers": {MCP_PROTOCOL_VERSION_HEADER: version, MCP_METHOD_HEADER: data["method"]}, } return await self._dispatcher.send_raw_request(data["method"], data.get("params"), opts) @@ -759,6 +761,16 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None types.ListToolsResult, ) + if self._negotiated_version in MODERN_PROTOCOL_VERSIONS: + # 2026-07-28: clients MUST drop tools whose x-mcp-header annotations are invalid. + kept: list[types.Tool] = [] + for tool in result.tools: + if (reason := find_invalid_x_mcp_header(tool.input_schema)) is not None: + logger.warning("dropping tool %r: invalid x-mcp-header (%s)", tool.name, reason) + continue + kept.append(tool) + result.tools = kept + # Cache tool output schemas for future validation # Note: don't clear the cache, as we may be using a cursor for tool in result.tools: diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index 9a42d64dd..cecf21f08 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -41,7 +41,11 @@ from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings from mcp.shared.dispatcher import CallOptions from mcp.shared.exceptions import NoBackChannelError -from mcp.shared.inbound import ERROR_CODE_HTTP_STATUS, InboundLadderRejection, classify_inbound_request +from mcp.shared.inbound import ( + ERROR_CODE_HTTP_STATUS, + InboundLadderRejection, + classify_inbound_request, +) from mcp.shared.jsonrpc_dispatcher import handler_exception_to_error_data from mcp.shared.message import MessageMetadata, ServerMessageMetadata from mcp.shared.transport_context import TransportContext diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index f703e760f..aeb91fdfe 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -3,7 +3,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Generic -from mcp_types import LoggingLevel +from mcp_types import ClientCapabilities, InputResponseRequestParams, InputResponses, LoggingLevel from pydantic import AnyUrl, BaseModel from typing_extensions import deprecated @@ -58,6 +58,7 @@ async def my_tool(x: int, ctx: Context) -> str: _request_context: ServerRequestContext[LifespanContextT, RequestT] | None _mcp_server: MCPServer | None + _input_params: InputResponseRequestParams | None # TODO(maxisbey): Consider making request_context/mcp_server required, or refactor Context entirely. def __init__( @@ -65,12 +66,14 @@ def __init__( *, request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, mcp_server: MCPServer | None = None, + input_params: InputResponseRequestParams | None = None, # TODO(Marcelo): We should drop this kwargs parameter. **kwargs: Any, ): super().__init__(**kwargs) self._request_context = request_context self._mcp_server = mcp_server + self._input_params = input_params @property def mcp_server(self) -> MCPServer: @@ -219,6 +222,33 @@ def request_id(self) -> str: """Get the unique ID for this request.""" return str(self.request_context.request_id) + @property + def input_responses(self) -> InputResponses | None: + """Client responses to a prior `InputRequiredResult.input_requests`. + + `None` on the initial round, or when the client retried without + responses. + """ + return self._input_params.input_responses if self._input_params else None + + @property + def request_state(self) -> str | None: + """Opaque state echoed from a prior `InputRequiredResult.request_state`. + + `None` on the initial round. + """ + return self._input_params.request_state if self._input_params else None + + @property + def client_capabilities(self) -> ClientCapabilities | None: + """The client's declared capabilities for this connection. + + `None` when the client supplied no client info (e.g. an anonymous + stateless request without the reserved `_meta` keys). + """ + client_params = self.request_context.session.client_params + return client_params.capabilities if client_params else None + @property def session(self): """Access to the underlying session for advanced usage.""" diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 15308eefd..67c81c18a 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -24,6 +24,7 @@ GetPromptRequestParams, GetPromptResult, Icon, + InputRequiredResult, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -306,8 +307,8 @@ async def _handle_list_tools( async def _handle_call_tool( self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams - ) -> CallToolResult: - context = Context(request_context=ctx, mcp_server=self) + ) -> CallToolResult | InputRequiredResult: + context = Context(request_context=ctx, mcp_server=self, input_params=params) try: return await self.call_tool(params.name, params.arguments or {}, context) except MCPError: @@ -323,7 +324,7 @@ async def _handle_list_resources( async def _handle_read_resource( self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams ) -> ReadResourceResult: - context = Context(request_context=ctx, mcp_server=self) + context = Context(request_context=ctx, mcp_server=self, input_params=params) try: results = await self.read_resource(params.uri, context) except ResourceNotFoundError as err: @@ -365,7 +366,7 @@ async def _handle_list_prompts( async def _handle_get_prompt( self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams ) -> GetPromptResult: - context = Context(request_context=ctx, mcp_server=self) + context = Context(request_context=ctx, mcp_server=self, input_params=params) return await self.get_prompt(params.name, params.arguments, context) async def list_tools(self) -> list[MCPTool]: @@ -387,7 +388,7 @@ async def list_tools(self) -> list[MCPTool]: async def call_tool( self, name: str, arguments: dict[str, Any], context: Context[LifespanResultT, Any] | None = None - ) -> CallToolResult: + ) -> CallToolResult | InputRequiredResult: """Call a tool by name with arguments.""" if context is None: context = Context(mcp_server=self) diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index d0c679c05..97eb3909e 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -4,12 +4,12 @@ from collections.abc import Awaitable, Callable, Sequence from itertools import chain from types import GenericAlias -from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints +from typing import Annotated, Any, Union, cast, get_args, get_origin, get_type_hints import anyio import anyio.to_thread import pydantic_core -from mcp_types import CallToolResult, ContentBlock, TextContent +from mcp_types import CallToolResult, ContentBlock, InputRequiredResult, TextContent from pydantic import BaseModel, ConfigDict, Field, PydanticUserError, WithJsonSchema, create_model from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind @@ -29,6 +29,10 @@ logger = get_logger(__name__) +def _is_input_required_type(obj: Any) -> bool: + return isinstance(obj, type) and issubclass(obj, InputRequiredResult) + + class StrictJsonSchema(GenerateJsonSchema): """A JSON schema generator that raises exceptions instead of emitting warnings. @@ -88,9 +92,13 @@ async def call_fn_with_arg_validation( else: return await anyio.to_thread.run_sync(functools.partial(fn, **arguments_parsed_dict)) - def convert_result(self, result: Any) -> CallToolResult: + def convert_result(self, result: Any) -> CallToolResult | InputRequiredResult: """Convert a function call result into a `CallToolResult`. + An `InputRequiredResult` is passed through unchanged so the multi-round + flow surfaces on the wire as `resultType: "input_required"` rather than + being JSON-dumped into a text block. + Note: we build unstructured content here **even though the lowlevel server tool call handler provides generic backwards compatibility serialization of structured content**. This is for MCPServer backwards compatibility: we need to @@ -98,6 +106,8 @@ def convert_result(self, result: Any) -> CallToolResult: from function return values, whereas the lowlevel server simply serializes the structured output. """ + if isinstance(result, InputRequiredResult): + return result if isinstance(result, CallToolResult): if self.output_schema is not None: assert self.output_model is not None, "Output model must be set if output schema is defined" @@ -266,10 +276,33 @@ def func_metadata( # unknown (i.e. a bare `Final`). assert return_type_expr is not UNKNOWN + if _is_input_required_type(return_type_expr): + # A tool annotated to return only InputRequiredResult never produces structured content. + return FuncMetadata(arg_model=arguments_model) + + # The annotation fed to schema derivation. Starts as the raw return annotation (preserving any + # Annotated[...] wrapper) and is narrowed below if InputRequiredResult arms are stripped. + effective_annotation: Any = sig.return_annotation + if is_union_origin(get_origin(return_type_expr)): args = get_args(return_type_expr) - # Check if CallToolResult appears in the union (excluding None for Optional check) - if any(isinstance(arg, type) and issubclass(arg, CallToolResult) for arg in args if arg is not type(None)): + # InputRequiredResult is a control-flow signal, not data: strip it so the residual arms + # drive schema derivation. convert_result short-circuits on an InputRequiredResult instance + # before output validation, so the schema only ever sees the data arms at runtime. + residual = tuple(a for a in args if not _is_input_required_type(a)) + if not residual: + return FuncMetadata(arg_model=arguments_model) + if len(residual) != len(args): + # PEP 604 has no syntax for "union of a runtime tuple"; Union[...] is the only spelling. + effective_annotation = residual[0] if len(residual) == 1 else Union[residual] # noqa: UP007 + # Re-normalize so the residual is processed exactly as if it had been the declared + # return annotation: unwraps a top-level Annotated[...] arm and re-derives metadata, + # so the CallToolResult/BaseModel/TypedDict dispatch below sees the bare type. + inspected_return_ann = inspect_annotation(effective_annotation, annotation_source=AnnotationSource.FUNCTION) + return_type_expr = inspected_return_ann.type + if len(residual) > 1 and any( + isinstance(a, type) and issubclass(a, CallToolResult) for a in residual if a is not type(None) + ): raise InvalidSignature( f"Function {func.__name__}: CallToolResult cannot be used in Union or Optional types. " "To return empty results, use: CallToolResult(content=[])" @@ -295,7 +328,7 @@ def func_metadata( else: return FuncMetadata(arg_model=arguments_model) else: - original_annotation = sig.return_annotation + original_annotation = effective_annotation output_model, output_schema, wrap_output = _try_create_model_and_schema( original_annotation, return_type_expr, func.__name__ diff --git a/src/mcp/shared/inbound.py b/src/mcp/shared/inbound.py index f54f125e7..1c70e3d92 100644 --- a/src/mcp/shared/inbound.py +++ b/src/mcp/shared/inbound.py @@ -1,19 +1,23 @@ """Inbound request classification for the modern per-request-envelope path. -Pure module: no I/O, no transport, no ``mcp.server`` imports. Runs the +Pure module: no I/O, no transport, no `mcp.server` imports. Runs the validation ladder against a decoded JSON-RPC body and returns either an :class:`InboundModernRoute` (every rung passed) or an :class:`InboundLadderRejection` (the first rung that failed). Callers map a -rejection's ``code`` through :data:`ERROR_CODE_HTTP_STATUS` to pick the HTTP +rejection's `code` through :data:`ERROR_CODE_HTTP_STATUS` to pick the HTTP status. + +Also hosts the shared header-value codec and the `x-mcp-header` schema +validator so client emit and server validate read the same source of truth. """ import base64 +import binascii import re -from collections.abc import Mapping, Sequence +from collections.abc import Iterator, Mapping, Sequence from dataclasses import dataclass from types import MappingProxyType -from typing import Any, Final +from typing import Any, Final, cast from mcp_types import ( CLIENT_CAPABILITIES_META_KEY, @@ -39,8 +43,12 @@ "MCP_METHOD_HEADER", "MCP_NAME_HEADER", "MCP_PROTOCOL_VERSION_HEADER", + "NAME_BEARING_METHODS", + "X_MCP_HEADER_KEY", "classify_inbound_request", + "decode_header_value", "encode_header_value", + "find_invalid_x_mcp_header", ] MCP_PROTOCOL_VERSION_HEADER: Final = "mcp-protocol-version" @@ -52,17 +60,152 @@ MCP_NAME_HEADER: Final = "mcp-name" """Canonical lowercase name of the HTTP header carrying the resource name (tool/prompt/resource URI).""" -_B64_SENTINEL = re.compile(r"^=\?base64\?.*\?=$") +X_MCP_HEADER_KEY: Final = "x-mcp-header" +"""JSON-Schema property annotation that designates an `Mcp-Param-*` HTTP header.""" + +NAME_BEARING_METHODS: Final[Mapping[str, str]] = MappingProxyType( + { + "tools/call": "name", + "prompts/get": "name", + "resources/read": "uri", + } +) +"""Method → params key whose value is mirrored as the `Mcp-Name` HTTP header. + +Shared by client emit (which header to send) and server validate (which body +field to compare against), so both ends agree on the field by construction. +""" + +_B64_SENTINEL = re.compile(r"^=\?base64\?(?P.*)\?=$") # RFC 7230 token chars minus DEL; visible ASCII 0x20-0x7E is the practical bound for a header value. _HEADER_SAFE = re.compile(r"^[\x20-\x7E]*$") +# RFC 9110 §5.6.2 token: the only characters permitted in an HTTP field name. +_RFC9110_TOKEN = re.compile(r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$") +# JSON-Schema types the spec permits to carry `x-mcp-header` (transports.mdx +# §Custom Headers). `number` is explicitly forbidden — float→str is not +# portable across implementations. +_X_MCP_HEADER_PRIMITIVE_TYPES: Final = frozenset({"string", "integer", "boolean"}) + +# JSON Schema 2020-12 applicator keywords whose values are themselves schema +# positions, grouped by value shape. `properties` is handled separately as the +# only keyword that preserves the statically-reachable chain; every keyword +# here drops the chain to None. Instance-data keywords (`default`, `examples`, +# `const`, `enum`) and `$ref`/`$dynamicRef` are deliberately absent so the +# walk never mistakes data for an annotation and never dereferences. +_SUBSCHEMA_SINGLE: Final = frozenset( + { + "items", + "contains", + "unevaluatedItems", + "additionalProperties", + "propertyNames", + "unevaluatedProperties", + "not", + "if", + "then", + "else", + "contentSchema", + } +) +_SUBSCHEMA_LIST: Final = frozenset({"allOf", "anyOf", "oneOf", "prefixItems"}) +_SUBSCHEMA_MAP: Final = frozenset({"patternProperties", "dependentSchemas", "$defs", "definitions"}) + + +def _walk_schema_positions(root: Any) -> Iterator[tuple[tuple[str, ...] | None, dict[str, Any]]]: + """Yield `(properties_path, schema)` for every schema position in `root`. + + `properties_path` is the chain of `properties` keys from the root to the + position, or `None` once any other applicator keyword has been crossed. + The root itself yields `()`. Only the JSON Schema 2020-12 applicators + listed above are entered; instance-data keywords are not, and `$ref` is + not dereferenced, so the walk terminates on any finite JSON value. An + explicit stack keeps the function total even on pathologically deep input. + """ + stack: list[tuple[tuple[str, ...] | None, Any]] = [((), root)] + while stack: + path, node = stack.pop() + if not isinstance(node, dict): + continue + schema = cast(dict[str, Any], node) + yield path, schema + for kw, val in schema.items(): + if kw == "properties" and isinstance(val, dict): + for name, sub in cast(dict[str, Any], val).items(): + stack.append(((*path, name) if path is not None else None, sub)) + elif kw in _SUBSCHEMA_SINGLE: + stack.append((None, val)) + elif kw in _SUBSCHEMA_LIST and isinstance(val, list): + stack.extend((None, sub) for sub in cast(list[Any], val)) + elif kw in _SUBSCHEMA_MAP and isinstance(val, dict): + stack.extend((None, sub) for sub in cast(dict[str, Any], val).values()) def encode_header_value(value: str) -> str: + """Wrap `value` in the `=?base64?...?=` sentinel when it would not survive an HTTP field round-trip. + + Plain printable ASCII without leading/trailing whitespace passes verbatim; + anything else (control chars, non-ASCII, edge whitespace, or a value that + already looks like the sentinel) is base64-wrapped so the receiver can + recover the exact bytes. + """ if _HEADER_SAFE.fullmatch(value) and value == value.strip() and not _B64_SENTINEL.fullmatch(value): return value return f"=?base64?{base64.b64encode(value.encode('utf-8')).decode('ascii')}?=" +def decode_header_value(value: str | None) -> str | None: + """Inverse of :func:`encode_header_value`. + + Returns the value verbatim unless it carries the `=?base64?...?=` sentinel, + in which case the payload is decoded as UTF-8. A malformed sentinel (bad + base64 or bad UTF-8) yields `None` so a corrupt header never matches a body + value by accident. `None` in → `None` out so callers can pass + `headers.get(...)` directly. + """ + if value is None: + return None + m = _B64_SENTINEL.fullmatch(value) + if m is None: + return value + try: + return base64.b64decode(m.group("payload"), validate=True).decode("utf-8") + except (binascii.Error, UnicodeDecodeError): + return None + + +def find_invalid_x_mcp_header(input_schema: Any) -> str | None: + """Return a reason string if any `x-mcp-header` annotation in `input_schema` is invalid; else `None`. + + Walks every JSON Schema 2020-12 schema position. An annotation is valid + only when it sits on a property statically reachable from the root via a + chain of pure `properties` keys, names a non-empty RFC 9110 token, is on + an integer/string/boolean property, and is case-insensitively unique + across the whole schema. A `None` / non-mapping schema has no schema + positions and returns `None`. + """ + seen: dict[str, str] = {} + for path, schema in _walk_schema_positions(input_schema): + if X_MCP_HEADER_KEY not in schema: + continue + if not path: # None (off the pure-properties chain) or () (the root itself) + return f"{X_MCP_HEADER_KEY} found at a schema position not reachable via a pure `properties` chain" + where = ".".join(path) + header = schema[X_MCP_HEADER_KEY] + if not isinstance(header, str) or not _RFC9110_TOKEN.fullmatch(header): + return f"property {where!r}: {X_MCP_HEADER_KEY} {header!r} is not an RFC 9110 token" + prop_type = schema.get("type") + if not isinstance(prop_type, str) or prop_type not in _X_MCP_HEADER_PRIMITIVE_TYPES: + return ( + f"property {where!r}: {X_MCP_HEADER_KEY} is only permitted on " + f"integer/string/boolean properties (got {prop_type!r})" + ) + lower = header.lower() + if lower in seen: + return f"{X_MCP_HEADER_KEY} {header!r} on property {where!r} duplicates property {seen[lower]!r}" + seen[lower] = where + return None + + # INTERNAL_ERROR is deliberately unmapped (→ HTTP 200): the spec assigns no status to # -32603, and whether handler-origin errors get 5xx is an open S4 question — see TODO(L66). ERROR_CODE_HTTP_STATUS: Final[Mapping[int, int]] = MappingProxyType( @@ -76,7 +219,7 @@ def encode_header_value(value: str) -> str: METHOD_NOT_FOUND: 404, } ) -"""HTTP status to send for a JSON-RPC ``error.code``. +"""HTTP status to send for a JSON-RPC `error.code`. Consulted for classifier-origin *and* handler-origin errors, so one table decides the wire status regardless of where the error was produced. Unmapped @@ -88,7 +231,7 @@ def encode_header_value(value: str) -> str: class InboundModernRoute: """A modern-protocol request whose envelope passed every ladder rung. - ``client_info`` and ``client_capabilities`` are the raw envelope values; + `client_info` and `client_capabilities` are the raw envelope values; the classifier checks presence only, not shape. Method existence is not a ladder rung — kernel dispatch is the single source of truth for that. """ @@ -117,25 +260,26 @@ def classify_inbound_request( Rungs, in order — first failure wins: - 1. ``params._meta`` is a mapping carrying every reserved envelope key + 1. `params._meta` is a mapping carrying every reserved envelope key (protocol version, client info, client capabilities) → else :data:`~mcp_types.jsonrpc.INVALID_PARAMS`. - 2. When ``headers`` is given, its ``MCP-Protocol-Version`` entry equals - the envelope's protocol version → else - :data:`~mcp_types.jsonrpc.HEADER_MISMATCH`. Runs before the - supported-version rung so a client that disagrees with itself is told - so, rather than told the body's version is unsupported. - 3. The envelope's protocol version is in ``supported_modern_versions`` → + 2. When `headers` is given, `MCP-Protocol-Version` equals the envelope's + protocol version, `Mcp-Method` equals `body.method`, and — for the + methods in :data:`NAME_BEARING_METHODS` — `Mcp-Name` equals the named + body param → else :data:`~mcp_types.jsonrpc.HEADER_MISMATCH`. Runs + before the supported-version rung so a client that disagrees with itself + is told so, rather than told the body's version is unsupported. + 3. The envelope's protocol version is in `supported_modern_versions` → else :data:`~mcp_types.jsonrpc.UNSUPPORTED_PROTOCOL_VERSION` with - ``data = {"supported": [...], "requested": }``. + `data = {"supported": [...], "requested": }`. Method existence is *not* a rung: kernel dispatch owns that decision so custom-registered methods route and the answer lives in one place. Args: body: The decoded JSON-RPC request mapping. Envelope shape - (``jsonrpc`` / ``id``) is not checked here. - headers: Transport headers keyed by lowercase name, or ``None`` to + (`jsonrpc` / `id`) is not checked here. + headers: Transport headers keyed by lowercase name, or `None` to skip the header rung (non-HTTP callers). supported_modern_versions: Modern protocol revisions this server accepts on the per-request-envelope path. @@ -152,12 +296,27 @@ def classify_inbound_request( "client-capabilities envelope keys", ) - # TODO(L59): also validate Mcp-Method / Mcp-Name per SEP-2243 §Server Validation - if headers is not None and headers.get(MCP_PROTOCOL_VERSION_HEADER) != protocol_version: - return InboundLadderRejection( - code=HEADER_MISMATCH, - message=f"{MCP_PROTOCOL_VERSION_HEADER} header does not match the request envelope's protocol version", - ) + if headers is not None: + if headers.get(MCP_PROTOCOL_VERSION_HEADER) != protocol_version: + return InboundLadderRejection( + code=HEADER_MISMATCH, + message=f"{MCP_PROTOCOL_VERSION_HEADER} header does not match the request envelope's protocol version", + ) + method: Any = body.get("method") + if headers.get(MCP_METHOD_HEADER) != method: + return InboundLadderRejection( + code=HEADER_MISMATCH, + message=f"{MCP_METHOD_HEADER} header does not match the request body's method", + ) + name_key = NAME_BEARING_METHODS.get(method) + if name_key is not None: + # Rung 1 already proved body["params"] is a mapping. + body_value = body["params"].get(name_key) + if body_value is not None and decode_header_value(headers.get(MCP_NAME_HEADER)) != body_value: + return InboundLadderRejection( + code=HEADER_MISMATCH, + message=f"{MCP_NAME_HEADER} header does not match the request body's {name_key!r} parameter", + ) if protocol_version not in supported_modern_versions: return InboundLadderRejection( diff --git a/tests/client/test_client.py b/tests/client/test_client.py index cc3ff4d96..f869d1f1b 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -472,6 +472,38 @@ async def scripted_transport() -> AsyncIterator[TransportStreams]: assert methods_seen == ["server/discover", "initialize", "notifications/initialized"] +@pytest.mark.anyio +async def test_modern_list_tools_drops_tools_with_invalid_x_mcp_header_but_legacy_does_not() -> None: + """At 2026-07-28 the spec requires clients to exclude tools whose `x-mcp-header` + annotation is malformed; handshake-era sessions surface them unchanged. Two + tools are advertised — one valid, one with a non-RFC-9110-token header name — + and the modern client sees only the valid one.""" + valid = types.Tool( + name="ok", + input_schema={"type": "object", "properties": {"a": {"type": "string", "x-mcp-header": "Region"}}}, + ) + bad = types.Tool( + name="dropme", + input_schema={"type": "object", "properties": {"a": {"type": "string", "x-mcp-header": "bad name"}}}, + ) + + async def on_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[valid, bad]) + + server = Server("test", on_list_tools=on_list_tools) + + with anyio.fail_after(5): + async with Client(server) as client: + result = await client.list_tools() + assert [t.name for t in result.tools] == ["ok"] + + async with Client(server, mode="legacy") as client: + result = await client.list_tools() + assert [t.name for t in result.tools] == ["ok", "dropme"] + + def test_client_rejects_handshake_era_mode_at_construction() -> None: """A handshake-era protocol-version string passed as `mode=` is rejected by `__post_init__` with a hint to use `mode='legacy'` — the version-pin path is diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index e17f2f18f..9c83e213c 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -206,7 +206,7 @@ async def test_unsupported_protocol_version_rejection_body_contains_the_sniffed_ response = await http.post( "/mcp", json={"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {"_meta": meta}}, - headers=base_headers() | {"mcp-protocol-version": bad}, + headers=base_headers() | {"mcp-protocol-version": bad, "mcp-method": "tools/list"}, ) assert response.status_code == 400 diff --git a/tests/server/mcpserver/test_func_metadata.py b/tests/server/mcpserver/test_func_metadata.py index 0329f6836..edc3decbd 100644 --- a/tests/server/mcpserver/test_func_metadata.py +++ b/tests/server/mcpserver/test_func_metadata.py @@ -10,7 +10,7 @@ import annotated_types import pytest from dirty_equals import IsPartialDict -from mcp_types import CallToolResult +from mcp_types import CallToolResult, InputRequiredResult from pydantic import BaseModel, Field from mcp.server.mcpserver.exceptions import InvalidSignature @@ -862,6 +862,29 @@ def func_returning_annotated_tool_call_result() -> Annotated[CallToolResult, Per assert isinstance(meta.convert_result(func_returning_annotated_tool_call_result()), CallToolResult) +def test_tool_call_result_annotated_unioned_with_input_required_result_is_equivalent_to_the_bare_annotated_form(): + """Stripping `InputRequiredResult` makes the residual behave exactly as if it were the + declared return annotation, including the `Annotated[CallToolResult, Model]` special case + — the schema derives from `Model` and `convert_result` validates `structured_content` + against it instead of wrapping the whole `CallToolResult`.""" + + class PersonClass(BaseModel): + name: str + + def fn_bare() -> Annotated[CallToolResult, PersonClass]: + return CallToolResult(content=[], structured_content={"name": "Brandon"}) + + def fn_iir() -> Annotated[CallToolResult, PersonClass] | InputRequiredResult: + return CallToolResult(content=[], structured_content={"name": "Brandon"}) + + bare = func_metadata(fn_bare) + iir = func_metadata(fn_iir) + assert iir.output_schema == bare.output_schema + assert iir.wrap_output == bare.wrap_output + assert isinstance(bare.convert_result(fn_bare()), CallToolResult) + assert isinstance(iir.convert_result(fn_iir()), CallToolResult) + + def test_tool_call_result_annotated_is_structured_and_invalid(): class PersonClass(BaseModel): name: str @@ -1038,7 +1061,9 @@ def func_with_aliases() -> ModelWithAliases: # pragma: no cover # Check that the actual output uses aliases too result = ModelWithAliases(**{"first": "hello", "second": "world"}) - structured_content = meta.convert_result(result).structured_content + converted = meta.convert_result(result) + assert isinstance(converted, CallToolResult) + structured_content = converted.structured_content assert structured_content is not None # The structured content should use aliases to match the schema @@ -1051,7 +1076,9 @@ def func_with_aliases() -> ModelWithAliases: # pragma: no cover # Also test the case where we have a model with defaults to ensure aliases work in all cases result_with_defaults = ModelWithAliases() # Uses default None values - structured_content_defaults = meta.convert_result(result_with_defaults).structured_content + converted_defaults = meta.convert_result(result_with_defaults) + assert isinstance(converted_defaults, CallToolResult) + structured_content_defaults = converted_defaults.structured_content assert structured_content_defaults is not None # Even with defaults, should use aliases in output @@ -1191,3 +1218,71 @@ def func_with_metadata() -> Annotated[int, Field(gt=1)]: ... # pragma: no branc assert meta.output_schema is not None assert meta.output_schema["properties"]["result"] == {"exclusiveMinimum": 1, "title": "Result", "type": "integer"} + + +def test_convert_result_passes_input_required_result_through_unchanged(): + def fn() -> str | InputRequiredResult: ... # pragma: no branch + + meta = func_metadata(fn) + irr = InputRequiredResult(request_state="opaque") + assert meta.convert_result(irr) is irr + + +def test_input_required_result_return_annotation_yields_no_output_schema(): + def fn() -> InputRequiredResult: ... # pragma: no branch + + meta = func_metadata(fn) + assert meta.output_schema is None + assert meta.output_model is None + + +def test_union_with_input_required_result_derives_schema_from_residual_arm(): + def fn() -> str | InputRequiredResult: ... # pragma: no branch + + meta = func_metadata(fn) + assert meta.output_schema is not None + assert meta.output_schema["properties"]["result"]["type"] == "string" + converted = meta.convert_result("hello") + assert isinstance(converted, CallToolResult) + assert converted.structured_content == {"result": "hello"} + irr = InputRequiredResult(request_state="opaque") + assert meta.convert_result(irr) is irr + + +def test_call_tool_result_unioned_with_input_required_result_is_accepted(): + def fn() -> CallToolResult | InputRequiredResult: ... # pragma: no branch + + meta = func_metadata(fn) + assert meta.output_schema is None + + +def test_basemodel_union_input_required_result_derives_model_schema(): + class Payload(BaseModel): + x: int + + def fn() -> Payload | InputRequiredResult: ... # pragma: no branch + + meta = func_metadata(fn) + assert meta.output_model is Payload + assert meta.wrap_output is False + assert meta.output_schema == Payload.model_json_schema() + + +def test_call_tool_result_in_union_with_input_required_result_is_still_rejected(): + def fn() -> CallToolResult | str | InputRequiredResult: ... # pragma: no branch + + with pytest.raises(InvalidSignature, match="CallToolResult cannot be used in Union"): + func_metadata(fn) + + +def test_union_of_only_input_required_subclasses_yields_no_output_schema(): + class StepA(InputRequiredResult): + pass + + class StepB(InputRequiredResult): + pass + + def fn() -> StepA | StepB: ... # pragma: no branch + + meta = func_metadata(fn) + assert meta.output_schema is None diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 4ea867480..47f3384a8 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -3,6 +3,7 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock, patch +import anyio import pytest from inline_snapshot import snapshot from mcp_types import ( @@ -11,15 +12,21 @@ AudioContent, BlobResourceContents, CallToolResult, + ClientCapabilities, Completion, CompletionArgument, CompletionContext, ContentBlock, + ElicitRequest, + ElicitRequestFormParams, + ElicitResult, EmbeddedResource, GetPromptResult, Icon, ImageContent, + InputRequiredResult, ListPromptsResult, + ListRootsRequest, Prompt, PromptArgument, PromptMessage, @@ -1570,3 +1577,98 @@ def get_user(user_id: str) -> str: assert exc_info.value.error.code == INVALID_PARAMS assert exc_info.value.error.data == {"uri": "resource://users/999"} + + +async def test_tool_returning_input_required_result_reaches_client_unchanged(): + mcp = MCPServer() + + @mcp.tool() + async def ask(ctx: Context) -> str | InputRequiredResult: + return InputRequiredResult(input_requests={"roots": ListRootsRequest()}, request_state="round-1") + + with anyio.fail_after(5): + async with Client(mcp, mode="2026-07-28") as client: + result = await client.call_tool("ask", allow_input_required=True) + + assert isinstance(result, InputRequiredResult) + assert result.request_state == "round-1" + assert result.input_requests is not None + assert result.input_requests["roots"].method == "roots/list" + + +async def test_tool_reads_input_responses_and_request_state_from_context_on_retry(): + mcp = MCPServer() + + @mcp.tool() + async def greet(ctx: Context) -> str | InputRequiredResult: + responses = ctx.input_responses + if responses and "who" in responses: + who = responses["who"] + assert isinstance(who, ElicitResult) and who.content is not None + return f"Hello, {who.content['name']}! (state={ctx.request_state})" + return InputRequiredResult( + input_requests={ + "who": ElicitRequest( + params=ElicitRequestFormParams( + message="What is your name?", + requested_schema={ + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + ) + ) + }, + request_state="r1", + ) + + with anyio.fail_after(5): + async with Client(mcp, mode="2026-07-28") as client: + r1 = await client.call_tool("greet", allow_input_required=True) + assert isinstance(r1, InputRequiredResult) + assert r1.input_requests is not None and "who" in r1.input_requests + + r2 = await client.call_tool( + "greet", + input_responses={"who": ElicitResult(action="accept", content={"name": "Alice"})}, + request_state=r1.request_state, + allow_input_required=True, + ) + assert isinstance(r2, CallToolResult) + block = r2.content[0] + assert isinstance(block, TextContent) + assert block.text == "Hello, Alice! (state=r1)" + + +async def test_context_exposes_client_capabilities_from_connection(): + mcp = MCPServer() + seen: list[ClientCapabilities | None] = [] + + @mcp.tool() + async def probe(ctx: Context) -> str: + seen.append(ctx.client_capabilities) + return "ok" + + with anyio.fail_after(5): + async with Client(mcp, mode="2026-07-28") as client: + await client.call_tool("probe") + + assert len(seen) == 1 + assert isinstance(seen[0], ClientCapabilities) + + +async def test_context_input_responses_and_request_state_are_none_on_initial_round(): + mcp = MCPServer() + captured: dict[str, Any] = {} + + @mcp.tool() + async def probe(ctx: Context) -> str: + captured["responses"] = ctx.input_responses + captured["state"] = ctx.request_state + return "ok" + + with anyio.fail_after(5): + async with Client(mcp, mode="2026-07-28") as client: + await client.call_tool("probe") + + assert captured == {"responses": None, "state": None} diff --git a/tests/server/test_streamable_http_modern.py b/tests/server/test_streamable_http_modern.py index 08b940107..0ba61cf39 100644 --- a/tests/server/test_streamable_http_modern.py +++ b/tests/server/test_streamable_http_modern.py @@ -15,6 +15,7 @@ from mcp_types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, + HEADER_MISMATCH, INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, @@ -39,7 +40,7 @@ ) from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError, NoBackChannelError -from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER +from mcp.shared.inbound import MCP_METHOD_HEADER, MCP_NAME_HEADER, MCP_PROTOCOL_VERSION_HEADER from mcp.shared.transport_context import TransportContext pytestmark = pytest.mark.anyio @@ -67,7 +68,10 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: return httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://testserver", - headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}, + headers={ + MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION, + "content-type": "application/json", + }, ) @@ -150,7 +154,7 @@ async def greet(ctx: ServerRequestContext, params: PaginatedRequestParams) -> di body["method"] = "custom/greet" body["params"]["_meta"][CLIENT_INFO_META_KEY] = "not-an-object" async with _asgi_client(server) as http: - response = await http.post("/mcp", json=body, headers={"content-type": "application/json"}) + response = await http.post("/mcp", json=body, headers={MCP_METHOD_HEADER: "custom/greet"}) assert response.status_code == 200 assert response.json()["result"] == {"ok": True} assert seen == [None] @@ -175,7 +179,7 @@ async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | with caplog.at_level(logging.ERROR, logger=runner.__name__): async with _asgi_client(Server("test", on_list_tools=list_tools)) as http: - response = await http.post("/mcp", json=_list_tools_body(), headers={"content-type": "application/json"}) + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) assert response.status_code == 200 assert response.json()["result"]["tools"] == [] @@ -203,7 +207,7 @@ async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | with anyio.fail_after(5), caplog.at_level(logging.WARNING, logger=runner.__name__): async with _asgi_client(Server("test", on_list_tools=list_tools)) as http: - response = await http.post("/mcp", json=_list_tools_body(), headers={"content-type": "application/json"}) + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "tools/list"}) # coverage.py on Python 3.11 misreports the lines below as unhit (the test passes there); # the shielded-cancel path inside the request task disrupts the tracer in this frame. assert response.status_code == 200 # pragma: lax no cover @@ -270,3 +274,30 @@ async def fail() -> dict[str, Any]: # Handler internals never reach the wire. assert "boom" not in reply.error.message assert "request handler raised" in caplog.text + + +# --- header cross-check at the wire -------------------------------------------- + + +async def test_handle_modern_request_rejects_mismatched_method_header_with_400_and_header_mismatch() -> None: + """Spec-mandated: an `Mcp-Method` header that disagrees with `body.method` is rejected at the + boundary as HTTP 400 with JSON-RPC error code HEADER_MISMATCH; the handler never runs.""" + async with _asgi_client(Server("test")) as http: + response = await http.post("/mcp", json=_list_tools_body(), headers={MCP_METHOD_HEADER: "prompts/list"}) + assert response.status_code == 400 + assert response.json()["error"]["code"] == HEADER_MISMATCH + + +async def test_handle_modern_request_rejects_mismatched_name_header_with_400_and_header_mismatch() -> None: + """Spec-mandated: for a name-bearing method, an `Mcp-Name` header that disagrees with the body's + named param is rejected as HTTP 400 with JSON-RPC error code HEADER_MISMATCH.""" + body = _list_tools_body() + body["method"] = "tools/call" + body["params"]["name"] = "real" + body["params"]["arguments"] = {} + async with _asgi_client(Server("test")) as http: + response = await http.post( + "/mcp", json=body, headers={MCP_METHOD_HEADER: "tools/call", MCP_NAME_HEADER: "wrong"} + ) + assert response.status_code == 400 + assert response.json()["error"]["code"] == HEADER_MISMATCH diff --git a/tests/shared/test_inbound.py b/tests/shared/test_inbound.py index 150cea6c2..93ab6ecc2 100644 --- a/tests/shared/test_inbound.py +++ b/tests/shared/test_inbound.py @@ -1,7 +1,7 @@ """Pure-function tests of :mod:`mcp.shared.inbound`. Independent verifier of the classifier: every ladder rung is exercised -pass+fail with no ``mcp.server`` / transport imports and no inlined error-code +pass+fail with no `mcp.server` / transport imports and no inlined error-code or protocol-version literals — all facts are imported from their one source. """ @@ -27,10 +27,16 @@ from mcp.shared.inbound import ( ERROR_CODE_HTTP_STATUS, + MCP_METHOD_HEADER, + MCP_NAME_HEADER, MCP_PROTOCOL_VERSION_HEADER, + NAME_BEARING_METHODS, InboundLadderRejection, InboundModernRoute, classify_inbound_request, + decode_header_value, + encode_header_value, + find_invalid_x_mcp_header, ) CLIENT_INFO = {"name": "t", "version": "0"} @@ -42,10 +48,11 @@ def envelope( *, version: str = LATEST_MODERN_VERSION, drop: frozenset[str] = frozenset(), + extra_params: dict[str, Any] | None = None, ) -> dict[str, Any]: - """Build a JSON-RPC body carrying a complete modern ``_meta`` envelope. + """Build a JSON-RPC body carrying a complete modern `_meta` envelope. - ``drop`` removes named envelope keys so rung-1 failures are driven from one + `drop` removes named envelope keys so rung-1 failures are driven from one table instead of repeating reserved-key constants per call site. """ meta: dict[str, Any] = { @@ -55,7 +62,22 @@ def envelope( } for key in drop: del meta[key] - return {"jsonrpc": "2.0", "id": 1, "method": method, "params": {"_meta": meta}} + params: dict[str, Any] = {"_meta": meta} + if extra_params: + params.update(extra_params) + return {"jsonrpc": "2.0", "id": 1, "method": method, "params": params} + + +def matching_headers(body: dict[str, Any]) -> dict[str, str]: + """The minimal lowercase HTTP header set that agrees with `body` for rung 2.""" + headers = { + MCP_PROTOCOL_VERSION_HEADER: body["params"]["_meta"][PROTOCOL_VERSION_META_KEY], + MCP_METHOD_HEADER: body["method"], + } + name_key = NAME_BEARING_METHODS.get(body["method"]) + if name_key is not None and name_key in body["params"]: + headers[MCP_NAME_HEADER] = encode_header_value(body["params"][name_key]) + return headers def assert_rejected(result: object, code: int) -> InboundLadderRejection: @@ -78,7 +100,7 @@ def assert_rejected(result: object, code: int) -> InboundLadderRejection: ], ) def test_envelope_rung_rejects_missing_keys(body: dict[str, Any]) -> None: - """Spec-mandated: a modern request lacking any of the three reserved ``_meta`` keys is rejected INVALID_PARAMS.""" + """Spec-mandated: a modern request lacking any of the three reserved `_meta` keys is rejected INVALID_PARAMS.""" rejection = assert_rejected(classify_inbound_request(body), INVALID_PARAMS) assert rejection.data is None @@ -94,7 +116,7 @@ def test_envelope_rung_rejects_missing_keys(body: dict[str, Any]) -> None: ], ) def test_envelope_rung_rejects_non_mapping_shapes(body: dict[str, Any]) -> None: - """Spec-mandated: non-mapping ``params`` / ``_meta`` cannot carry the envelope and reject INVALID_PARAMS.""" + """Spec-mandated: non-mapping `params` / `_meta` cannot carry the envelope and reject INVALID_PARAMS.""" assert_rejected(classify_inbound_request(body), INVALID_PARAMS) @@ -102,7 +124,7 @@ def test_envelope_rung_rejects_non_mapping_shapes(body: dict[str, Any]) -> None: def test_version_rung_rejects_unsupported_with_data_shape() -> None: - """Spec-mandated: an envelope version outside the modern set rejects with the ``supported``/``requested`` data.""" + """Spec-mandated: an envelope version outside the modern set rejects with the `supported`/`requested` data.""" rejection = assert_rejected( classify_inbound_request(envelope(version=LATEST_HANDSHAKE_VERSION)), UNSUPPORTED_PROTOCOL_VERSION, @@ -114,7 +136,7 @@ def test_version_rung_rejects_unsupported_with_data_shape() -> None: def test_version_rung_data_reflects_supplied_supported_list() -> None: - """SDK-defined: the caller-supplied ``supported_modern_versions`` is what rejection ``data.supported`` echoes.""" + """SDK-defined: the caller-supplied `supported_modern_versions` is what rejection `data.supported` echoes.""" custom = (LATEST_HANDSHAKE_VERSION,) rejection = assert_rejected( classify_inbound_request(envelope(), supported_modern_versions=custom), @@ -127,14 +149,15 @@ def test_version_rung_data_reflects_supplied_supported_list() -> None: def test_header_rung_does_not_reject_when_headers_arg_is_none() -> None: - """SDK-defined: ``headers=None`` (non-HTTP transports) means rung 3 has nothing to check and the ladder proceeds.""" + """SDK-defined: `headers=None` (non-HTTP transports) means rung 3 has nothing to check and the ladder proceeds.""" result = classify_inbound_request(envelope(), headers=None) assert isinstance(result, InboundModernRoute) def test_header_rung_passes_when_header_matches_envelope() -> None: """Spec-mandated: an HTTP version header equal to the envelope version passes rung 3.""" - result = classify_inbound_request(envelope(), headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) + body = envelope() + result = classify_inbound_request(body, headers=matching_headers(body)) assert isinstance(result, InboundModernRoute) @@ -150,12 +173,78 @@ def test_header_rung_rejects_on_disagreement(headers: dict[str, str]) -> None: assert_rejected(classify_inbound_request(envelope(), headers=headers), HEADER_MISMATCH) +@pytest.mark.parametrize( + "override", + [ + pytest.param({MCP_METHOD_HEADER: "prompts/list"}, id="method-mismatch"), + pytest.param({MCP_METHOD_HEADER: "TOOLS/LIST"}, id="method-case-mismatch"), + ], +) +def test_header_rung_rejects_method_header_disagreement(override: dict[str, str]) -> None: + """Spec-mandated: `Mcp-Method` must equal `body.method` exactly (case-sensitive) → else HEADER_MISMATCH.""" + body = envelope() + rejection = assert_rejected( + classify_inbound_request(body, headers=matching_headers(body) | override), HEADER_MISMATCH + ) + assert MCP_METHOD_HEADER in rejection.message + + +def test_header_rung_rejects_missing_method_header() -> None: + """Spec-mandated: an HTTP request on the modern path without `Mcp-Method` is HEADER_MISMATCH.""" + body = envelope() + headers = matching_headers(body) + del headers[MCP_METHOD_HEADER] + assert_rejected(classify_inbound_request(body, headers=headers), HEADER_MISMATCH) + + +@pytest.mark.parametrize( + ("method", "name_key"), + [(m, k) for m, k in NAME_BEARING_METHODS.items()], +) +def test_header_rung_rejects_missing_or_mismatched_name_header_for_name_bearing_methods( + method: str, name_key: str +) -> None: + """Spec-mandated: when the body carries the named param, `Mcp-Name` must be present and equal it.""" + body = envelope(method, extra_params={name_key: "expected"}) + headers = matching_headers(body) + # Mismatch + assert_rejected(classify_inbound_request(body, headers=headers | {MCP_NAME_HEADER: "wrong"}), HEADER_MISMATCH) + # Absent + del headers[MCP_NAME_HEADER] + assert_rejected(classify_inbound_request(body, headers=headers), HEADER_MISMATCH) + + +def test_header_rung_decodes_base64_sentinel_before_comparing_name() -> None: + """Spec-mandated: servers MUST decode the `=?base64?...?=` sentinel before comparing `Mcp-Name`.""" + body = envelope("tools/call", extra_params={"name": "résumé"}) + headers = matching_headers(body) + assert headers[MCP_NAME_HEADER].startswith("=?base64?") + result = classify_inbound_request(body, headers=headers) + assert isinstance(result, InboundModernRoute) + + +def test_header_rung_does_not_require_name_header_for_non_name_bearing_method() -> None: + """SDK-defined: a method outside `NAME_BEARING_METHODS` ignores `Mcp-Name` entirely.""" + body = envelope("tools/list") + result = classify_inbound_request(body, headers=matching_headers(body) | {MCP_NAME_HEADER: "anything"}) + assert isinstance(result, InboundModernRoute) + + +def test_header_rung_does_not_require_name_header_when_body_omits_the_named_param() -> None: + """SDK-defined: a name-bearing method whose body lacks the named param skips the `Mcp-Name` + check — the param's absence is INVALID_PARAMS later, not HEADER_MISMATCH here.""" + body = envelope("tools/call") + result = classify_inbound_request(body, headers=matching_headers(body)) + assert isinstance(result, InboundModernRoute) + + # --- all rungs pass ------------------------------------------------------------ def test_all_rungs_pass_yields_route() -> None: """Spec-mandated: a complete envelope at a supported version with agreeing header routes, surfacing the envelope.""" - result = classify_inbound_request(envelope(), headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) + body = envelope() + result = classify_inbound_request(body, headers=matching_headers(body)) assert isinstance(result, InboundModernRoute) assert result.protocol_version == LATEST_MODERN_VERSION assert result.client_info == CLIENT_INFO @@ -165,7 +254,8 @@ def test_all_rungs_pass_yields_route() -> None: @pytest.mark.parametrize("method", ["initialize", "myorg/custom", "does/not/exist"]) def test_classifier_passes_unknown_method_through_to_route(method: str) -> None: """SDK-defined: the classifier does not gate on method — kernel dispatch is the single owner of that decision.""" - result = classify_inbound_request(envelope(method), headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) + body = envelope(method) + result = classify_inbound_request(body, headers=matching_headers(body)) assert isinstance(result, InboundModernRoute) @@ -215,3 +305,173 @@ def test_verdict_dataclasses_are_frozen() -> None: for verdict in (route, rejection): with pytest.raises(dataclasses.FrozenInstanceError): setattr(verdict, "message", "mutated") + + +# --- header-value codec -------------------------------------------------------- + + +@pytest.mark.parametrize( + "raw", + ["plain", "with internal space", "", " edge-ws ", "résumé", "a\r\nb", "=?base64?Zm9v?="], +) +def test_decode_header_value_round_trips_encode(raw: str) -> None: + """SDK-defined: `decode_header_value` is the exact inverse of `encode_header_value` over the full input domain.""" + assert decode_header_value(encode_header_value(raw)) == raw + + +def test_decode_header_value_passes_none_and_plain_through() -> None: + """SDK-defined: `None` in → `None` out so callers can pass `headers.get(...)` directly; plain stays verbatim.""" + assert decode_header_value(None) is None + assert decode_header_value("plain") == "plain" + + +@pytest.mark.parametrize("bad", ["=?base64?not base64!?=", "=?base64?gA==?="]) +def test_decode_header_value_returns_none_for_malformed_sentinel(bad: str) -> None: + """SDK-defined: a sentinel with bad base64 or bad UTF-8 decodes to `None`, so it can never match a body value.""" + assert decode_header_value(bad) is None + + +# --- NAME_BEARING_METHODS ------------------------------------------------------ + + +def test_name_bearing_methods_table_matches_spec() -> None: + """Spec-mandated: pins the method → name-param table the client emit and server validate share.""" + assert NAME_BEARING_METHODS == {"tools/call": "name", "prompts/get": "name", "resources/read": "uri"} + + +# --- find_invalid_x_mcp_header ------------------------------------------------- + + +def _schema(**props: Any) -> dict[str, Any]: + return {"type": "object", "properties": props} + + +@pytest.mark.parametrize( + "input_schema", + [ + pytest.param(None, id="none"), + pytest.param("not-a-mapping", id="non-mapping"), + pytest.param({"type": "object"}, id="no-properties"), + pytest.param({"type": "object", "properties": "not-a-mapping"}, id="properties-non-mapping"), + pytest.param(_schema(a={"type": "string"}), id="no-annotation"), + pytest.param(_schema(a={"type": "string", "x-mcp-header": "Region"}), id="valid-string"), + pytest.param(_schema(a={"type": "integer", "x-mcp-header": "Count"}), id="valid-integer"), + pytest.param(_schema(a={"type": "boolean", "x-mcp-header": "Flag"}), id="valid-boolean"), + pytest.param( + _schema(a={"type": "string", "x-mcp-header": "A"}, b={"type": "string", "x-mcp-header": "B"}), + id="two-distinct", + ), + pytest.param(_schema(a="not-a-mapping", b={"type": "string", "x-mcp-header": "B"}), id="non-mapping-prop"), + pytest.param( + _schema(outer={"type": "object", "properties": {"r": {"type": "string", "x-mcp-header": "R"}}}), + id="nested-on-properties-chain", + ), + pytest.param( + _schema(a={"type": "string", "default": {"x-mcp-header": "ignored"}}), + id="annotation-lookalike-in-default-is-data", + ), + pytest.param( + _schema(a={"type": "string", "examples": [{"x-mcp-header": "ignored"}]}), + id="annotation-lookalike-in-examples-is-data", + ), + pytest.param( + _schema(a={"type": "string", "const": {"x-mcp-header": "ignored"}}), + id="annotation-lookalike-in-const-is-data", + ), + pytest.param( + {"properties": {"a": {"type": "string", "x-mcp-header": "R"}}, "$ref": "#/$defs/loop"}, + id="ref-is-not-dereferenced", + ), + pytest.param( + {"type": "object", "allOf": 0, "anyOf": [], "$defs": 0, "patternProperties": {}}, + id="malformed-or-empty-applicators-ignored", + ), + ], +) +def test_find_invalid_x_mcp_header_accepts_valid_or_absent_annotations(input_schema: Any) -> None: + """Spec-mandated: a schema without annotations, or with annotations that are RFC 9110 tokens on + integer/string/boolean properties reachable via a pure `properties` chain and case-insensitively + unique across the whole schema, is valid.""" + assert find_invalid_x_mcp_header(input_schema) is None + + +@pytest.mark.parametrize( + "input_schema", + [ + pytest.param(_schema(a={"type": "string", "x-mcp-header": ""}), id="empty"), + pytest.param(_schema(a={"type": "string", "x-mcp-header": "My Region"}), id="space"), + pytest.param(_schema(a={"type": "string", "x-mcp-header": "Region:Primary"}), id="colon"), + pytest.param(_schema(a={"type": "string", "x-mcp-header": "Région"}), id="non-ascii"), + pytest.param(_schema(a={"type": "string", "x-mcp-header": "Region\t1"}), id="control-char"), + pytest.param(_schema(a={"type": "string", "x-mcp-header": 42}), id="non-string"), + pytest.param(_schema(a={"type": "object", "x-mcp-header": "Data"}), id="on-object"), + pytest.param(_schema(a={"type": "array", "x-mcp-header": "Items"}), id="on-array"), + pytest.param(_schema(a={"type": "null", "x-mcp-header": "Nil"}), id="on-null"), + pytest.param(_schema(a={"type": "number", "x-mcp-header": "Ratio"}), id="on-number"), + pytest.param(_schema(a={"type": ["string", "null"], "x-mcp-header": "Maybe"}), id="array-type"), + pytest.param(_schema(a={"type": {"not": "valid"}, "x-mcp-header": "Bad"}), id="dict-type"), + pytest.param(_schema(a={"x-mcp-header": "NoType"}), id="missing-type"), + pytest.param( + _schema(a={"type": "string", "x-mcp-header": "Region"}, b={"type": "string", "x-mcp-header": "Region"}), + id="duplicate-same-case", + ), + pytest.param( + _schema(a={"type": "string", "x-mcp-header": "MyField"}, b={"type": "string", "x-mcp-header": "myfield"}), + id="duplicate-diff-case", + ), + pytest.param( + _schema(a={"type": "array", "items": {"type": "string", "x-mcp-header": "X"}}), + id="under-items", + ), + pytest.param( + {"allOf": [{"properties": {"a": {"type": "string", "x-mcp-header": "X"}}}]}, + id="under-allOf", + ), + pytest.param( + {"oneOf": [{"type": "string", "x-mcp-header": "X"}]}, + id="under-oneOf", + ), + pytest.param( + _schema(a={"if": {"type": "string", "x-mcp-header": "X"}}), + id="under-if", + ), + pytest.param( + {"$defs": {"T": {"type": "string", "x-mcp-header": "X"}}, "properties": {}}, + id="under-defs", + ), + pytest.param( + {"patternProperties": {"^a": {"type": "string", "x-mcp-header": "X"}}}, + id="under-patternProperties", + ), + pytest.param( + {"type": "string", "x-mcp-header": "X"}, + id="on-root-schema", + ), + pytest.param( + _schema( + a={"type": "string", "x-mcp-header": "Region"}, + o={"type": "object", "properties": {"b": {"type": "string", "x-mcp-header": "region"}}}, + ), + id="duplicate-across-nesting-levels", + ), + pytest.param( + _schema(outer={"type": "object", "properties": {"r": {"type": "string", "x-mcp-header": "bad name"}}}), + id="nested-bad-token", + ), + pytest.param( + _schema(outer={"type": "object", "properties": {"r": {"type": "object", "x-mcp-header": "R"}}}), + id="nested-non-primitive", + ), + ], +) +def test_find_invalid_x_mcp_header_rejects_malformed_annotations(input_schema: dict[str, Any]) -> None: + """Spec-mandated: empty / non-token / non-primitive / off-chain / duplicate `x-mcp-header` + annotations yield a reason string.""" + assert isinstance(find_invalid_x_mcp_header(input_schema), str) + + +def test_find_invalid_x_mcp_header_reports_dotted_path_for_nested_property() -> None: + """SDK-defined: the reason string names the nested property by its dotted `properties` path.""" + schema = _schema(outer={"type": "object", "properties": {"r": {"type": "object", "x-mcp-header": "R"}}}) + reason = find_invalid_x_mcp_header(schema) + assert reason is not None and "'outer.r'" in reason