From dc84bb36ff1facde78008853f93b79f473e4ae0a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 11:26:17 +0200 Subject: [PATCH 1/5] Drive resolver elicitation over the 2026-07-28 input_required flow Resolvers that return Elicit[T] now negotiate the transport by protocol version: at >= 2026-07-28 the framework returns an InputRequiredResult carrying the batched questions and resumes when the client retries with input_responses/request_state; at <= 2025-11-25 it keeps the synchronous ctx.elicit() request. Author-facing code (Resolve/Elicit) is unchanged. resolve_arguments becomes a resumable DAG walk: it reads ctx.input_responses / ctx.request_state, memoizes resolver outcomes by a process-stable module:qualname key, batches independent pending elicitations into one round, serializes dependent ones across rounds, and carries resolved outcomes in request_state so each resolver resolves once per logical call. Outcomes restored from request_state are re-validated into their model via the Elicit[T] return arm. request_state is client-trusted for now (HMAC sealing is a follow-up). Add a render_elicitation_schema helper to elicitation.py, MRTR-loop and codec tests, and document the transport in the migration guide. --- docs/migration.md | 2 + src/mcp/server/elicitation.py | 15 +- src/mcp/server/mcpserver/resolve.py | 238 +++++++++++++++++++--- src/mcp/server/mcpserver/tools/base.py | 11 +- tests/server/mcpserver/test_resolve.py | 270 ++++++++++++++++++++++++- 5 files changed, 500 insertions(+), 36 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 19475ee27..fb3784681 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1515,6 +1515,8 @@ async def delete_folder( The `confirm_delete` resolver reads the tool's own `path` argument by name, lists the folder, and only elicits when the folder is non-empty - an empty folder resolves to `Confirm(ok=True)` with no round-trip to the client. Because `delete_folder` annotates the result union, it handles every outcome: the user accepting and confirming, accepting but declining to delete (`ok=False`), declining the elicitation, or cancelling it. +The framework drives elicitation over whichever transport the negotiated protocol provides, so the resolver and tool code above is unchanged either way. At `2026-07-28` and later it returns an `InputRequiredResult` carrying the questions and resumes when the client retries `call_tool(..., input_responses=..., request_state=...)` (independent resolvers are batched into one round; a resolver that depends on another's answer is asked in a later round). At `2025-11-25` and earlier it issues a synchronous `elicitation/create` request mid-call. Resolved outcomes are carried in `request_state` across rounds so each resolver resolves once per call. + Resolved parameters are omitted from the tool's input schema, so the client never supplies them. Resolver parameters that cannot be classified, and cyclic resolver dependencies, raise at registration time. ## Need Help? diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index c6faf0065..2f548f64e 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -87,6 +87,18 @@ def _validate_rendered_properties(json_schema: dict[str, Any]) -> None: ) from None +def render_elicitation_schema(schema: type[BaseModel]) -> dict[str, Any]: + """Render a model as the spec-valid `requested_schema` for an elicitation. + + Raises: + TypeError: If a field renders as something the spec's + `PrimitiveSchemaDefinition` does not accept. + """ + json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema) + _validate_rendered_properties(json_schema) + return json_schema + + async def elicit_with_validation( session: ServerSession, message: str, @@ -103,8 +115,7 @@ async def elicit_with_validation( For sensitive data like credentials or OAuth flows, use elicit_url() instead. """ - json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema) - _validate_rendered_properties(json_schema) + json_schema = render_elicitation_schema(schema) result = await session.elicit_form( message=message, diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 09c477599..2bc6e15d4 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -7,24 +7,40 @@ `Elicit[T]` to ask the client; the framework runs the elicitation and injects the answer. +The framework picks the elicitation transport from the negotiated protocol. At +>= 2026-07-28 it returns an `InputRequiredResult` carrying the batched questions +and resumes when the client retries with `input_responses`/`request_state` +(independent resolvers are asked in one round; a resolver depending on another's +answer is asked in a later round). At <= 2025-11-25 it issues a synchronous +`elicitation/create` request mid-call. Resolved outcomes are carried in +`request_state` across rounds, so each resolver resolves once per logical call. + Whether the consumer receives the unwrapped model or the full `ElicitationResult` union is decided by the consumer's annotation: - `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call. - `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the full outcome; the consumer branches on accept/decline/cancel. - -Each resolver runs at most once per `tools/call` (memoized by function identity). """ from __future__ import annotations import inspect +import json import typing from collections.abc import Callable, Hashable, Mapping from typing import Annotated, Any, Generic, cast, get_args, get_origin import anyio.to_thread +from mcp_types import ( + ElicitRequest, + ElicitRequestFormParams, + ElicitResult, + InputRequests, + InputRequiredResult, + InputResponses, +) +from mcp_types.version import LATEST_MODERN_VERSION, is_version_at_least from pydantic import BaseModel from typing_extensions import TypeVar @@ -33,6 +49,7 @@ CancelledElicitation, DeclinedElicitation, ElicitationResult, + render_elicitation_schema, ) from mcp.server.mcpserver.context import Context from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError @@ -43,6 +60,11 @@ # The union members the framework injects when a consumer opts into the outcome. _ELICITATION_RESULT_MEMBERS = (AcceptedElicitation, DeclinedElicitation, CancelledElicitation) +# First protocol revision whose `tools/call` carries elicitation inside +# `InputRequiredResult` rather than as a standalone server-to-client request. +_INPUT_REQUIRED_VERSION = LATEST_MODERN_VERSION # "2026-07-28" +_STATE_VERSION = 1 + class Resolve: """Marker for `Annotated[T, Resolve(fn)]`: fill the parameter by running `fn`.""" @@ -79,10 +101,19 @@ def __init__(self, kind: str, resolve: Resolve | None = None, wants_union: bool class _ResolverPlan: """A resolver's parameters and whether it is async, analyzed once.""" - def __init__(self, fn: Callable[..., Any], params: dict[str, _ParamPlan], is_async: bool) -> None: + def __init__( + self, + fn: Callable[..., Any], + params: dict[str, _ParamPlan], + is_async: bool, + elicit_schema: type[BaseModel] | None, + ) -> None: self.fn = fn self.params = params self.is_async = is_async + # The `T` from the resolver's `Elicit[T]` return arm, if annotated. Used to + # re-validate an outcome restored from `request_state` into a model. + self.elicit_schema = elicit_schema def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: @@ -125,6 +156,21 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, return resolved +def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: + """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. + + Lets an outcome restored from `request_state` (a plain dict) be re-validated + into its model so dependent resolvers and tools receive a typed value. + """ + candidates = get_args(return_annotation) if get_origin(return_annotation) is not None else (return_annotation,) + for candidate in candidates: + if get_origin(candidate) is Elicit: + schema = get_args(candidate)[0] + if isinstance(schema, type) and issubclass(schema, BaseModel): # pragma: no branch + return schema + return None + + def _wants_union(type_arg: Any) -> bool: """True when `type_arg` is an `ElicitationResult` member (or a union of them). @@ -202,7 +248,7 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" ) - plans[key] = _ResolverPlan(fn, params, is_async_callable(fn)) + plans[key] = _ResolverPlan(fn, params, is_async_callable(fn), _elicit_return_schema(hints.get("return"))) for dep in nested: analyze(dep, stack + (key,)) @@ -226,50 +272,111 @@ def _is_context_annotation(annotation: Any) -> bool: return any(isinstance(c, type) and issubclass(c, Context) for c in candidates) +class _Pending(Exception): + """Internal: a resolver needs client input not yet available this round.""" + + +class _Resolution: + """Per-`tools/call` resolution state, shared across the DAG walk. + + `input_required` selects the transport: at >= 2026-07-28 elicitations are + batched into `pending` and surfaced as an `InputRequiredResult`; at older + revisions each `Elicit` is answered synchronously via `ctx.elicit`. + """ + + def __init__( + self, + plans: Mapping[Hashable, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], + input_required: bool, + ) -> None: + self.plans = plans + self.tool_args = tool_args + self.context = context + self.input_required = input_required + self.answers: InputResponses = context.input_responses or {} if input_required else {} + self.state = _decode_state(context.request_state) if input_required else {} + self.cache: dict[str, ElicitationResult[Any]] = {} + self.pending: dict[str, ElicitRequest] = {} + + +def _state_key(fn: Callable[..., Any]) -> str: + """Process-stable wire key for a resolver. + + `id`-based keys aren't stable across `input_required` rounds (a retry may land + on a different worker), so memoize and key `input_requests`/`request_state` by + the resolver's `module:qualname`. Two consumers of the same resolver therefore + share one cache entry, one question, and one stored outcome. + """ + return f"{getattr(fn, '__module__', '')}:{getattr(fn, '__qualname__', fn)}" + + async def resolve_arguments( resolved_params: Mapping[str, tuple[Resolve, bool]], plans: Mapping[Hashable, _ResolverPlan], tool_args: Mapping[str, Any], context: Context[Any, Any], -) -> dict[str, Any]: +) -> dict[str, Any] | InputRequiredResult: """Resolve every `Resolve`-marked tool parameter into a concrete value. - Each resolver runs at most once (memoized by function identity). Returns a - mapping of tool parameter name to the value to inject. + Returns the mapping of tool parameter name to injected value when every + resolver is satisfied. When a resolver still needs client input (and the + negotiated protocol is >= 2026-07-28), returns an `InputRequiredResult` + carrying the batched questions instead; the tool body is not run. + + Each resolver runs at most once per logical call - across multiple + `input_required` rounds, resolved outcomes are carried in `request_state`. Raises: ToolError: If an elicited value is declined or cancelled and the consumer asked for the unwrapped model (rather than the result union). """ - cache: dict[Hashable, ElicitationResult[Any]] = {} + res = _Resolution(plans, tool_args, context, uses_input_required(context.request_context.protocol_version)) injected: dict[str, Any] = {} for name, (marker, wants_union) in resolved_params.items(): - outcome = await _resolve(marker.fn, plans, tool_args, context, cache) + try: + outcome = await _resolve(marker.fn, res) + except _Pending: + continue injected[name] = outcome if wants_union else _unwrap(outcome, name) + + if res.pending: + return InputRequiredResult( + input_requests=cast("InputRequests", res.pending), + request_state=_encode_state(res.cache), + ) return injected -async def _resolve( - fn: Callable[..., Any], - plans: Mapping[Hashable, _ResolverPlan], - tool_args: Mapping[str, Any], - context: Context[Any, Any], - cache: dict[Hashable, ElicitationResult[Any]], -) -> ElicitationResult[Any]: - key = _resolver_key(fn) - if key in cache: - return cache[key] +async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResult[Any]: + """Resolve one resolver, memoized by its process-stable state key. + + Raises `_Pending` when the resolver (or one of its dependencies) needs client + input that has not arrived yet. + """ + key = _state_key(fn) + if key in res.cache: + return res.cache[key] + if key in res.pending: + # Already asked this round by another consumer; don't run the resolver again. + raise _Pending + + plan = res.plans[_resolver_key(fn)] + if key in res.state: + outcome = _outcome_from_state(res.state[key], plan.elicit_schema) + res.cache[key] = outcome + return outcome - plan = plans[key] kwargs: dict[str, Any] = {} for param_name, param_plan in plan.params.items(): if param_plan.kind == "context": - kwargs[param_name] = context + kwargs[param_name] = res.context elif param_plan.kind == "by_name": - kwargs[param_name] = tool_args[param_name] + kwargs[param_name] = res.tool_args[param_name] else: assert param_plan.resolve is not None - dep_outcome = await _resolve(param_plan.resolve.fn, plans, tool_args, context, cache) + dep_outcome = await _resolve(param_plan.resolve.fn, res) kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) if plan.is_async: @@ -277,25 +384,102 @@ async def _resolve( else: result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) - outcome: ElicitationResult[Any] if isinstance(result, Elicit): - elicit = cast("Elicit[BaseModel]", result) - outcome = await context.elicit(elicit.message, elicit.schema) + outcome = await _elicit(cast("Elicit[BaseModel]", result), key, res) else: # A resolver may return any type (not just `BaseModel`); `model_construct` # wraps it as an accepted result without validating against the schema bound. outcome = cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=result)) - cache[key] = outcome + res.cache[key] = outcome return outcome +async def _elicit(elicit: Elicit[BaseModel], key: str, res: _Resolution) -> ElicitationResult[Any]: + """Turn a resolver's `Elicit` into an outcome via the negotiated transport.""" + if not res.input_required: + return await res.context.elicit(elicit.message, elicit.schema) + + answer = res.answers.get(key) + if answer is None: + res.pending[key] = _elicit_request(elicit) + raise _Pending + if not isinstance(answer, ElicitResult): + raise ToolError(f"Resolver {key!r} received a non-elicitation response") + if answer.action == "accept" and answer.content is not None: + return AcceptedElicitation(data=elicit.schema.model_validate(answer.content)) + if answer.action == "decline": + return DeclinedElicitation() + return CancelledElicitation() + + def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: if isinstance(outcome, AcceptedElicitation): return outcome.data raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") +def uses_input_required(protocol_version: str | None) -> bool: + """True when this request must elicit via `InputRequiredResult` (>= 2026-07-28). + + Older revisions still carry a standalone `elicitation/create` server-to-client + request, so the framework keeps the synchronous `ctx.elicit()` path for them. + """ + return protocol_version is not None and is_version_at_least(protocol_version, _INPUT_REQUIRED_VERSION) + + +def _elicit_request(elicit: Elicit[Any]) -> ElicitRequest: + """Render an `Elicit[T]` as the embedded `elicitation/create` request for `input_requests`.""" + json_schema = render_elicitation_schema(elicit.schema) + return ElicitRequest(params=ElicitRequestFormParams(message=elicit.message, requested_schema=json_schema)) + + +def _decode_state(request_state: str | None) -> dict[str, dict[str, Any]]: + """Decode the per-call resolution progress from `request_state`. + + `request_state` is client-trusted (integrity sealing is a follow-up); decode + defensively and treat anything malformed as "no progress yet". + """ + if not request_state: + return {} + try: + decoded: Any = json.loads(request_state) + except json.JSONDecodeError: + return {} + if not isinstance(decoded, dict): + return {} + payload = cast("dict[str, Any]", decoded) + if payload.get("v") != _STATE_VERSION: + return {} + outcomes = payload.get("outcomes") + return cast("dict[str, dict[str, Any]]", outcomes) if isinstance(outcomes, dict) else {} + + +def _encode_state(outcomes: Mapping[str, ElicitationResult[Any]]) -> str: + """Encode resolved outcomes (keyed by resolver path) for the next round.""" + encoded: dict[str, dict[str, Any]] = {} + for path, outcome in outcomes.items(): + entry: dict[str, Any] = {"action": outcome.action} + if isinstance(outcome, AcceptedElicitation): + data = outcome.data + entry["data"] = data.model_dump(mode="json") if isinstance(data, BaseModel) else data + encoded[path] = entry + return json.dumps({"v": _STATE_VERSION, "outcomes": encoded}) + + +def _outcome_from_state(entry: Mapping[str, Any], schema: type[BaseModel] | None) -> ElicitationResult[Any]: + """Rebuild an `ElicitationResult` from a decoded `request_state` entry.""" + action = entry.get("action") + if action == "decline": + return DeclinedElicitation() + if action == "cancel": + return CancelledElicitation() + data = entry.get("data") + if schema is not None and isinstance(data, dict): + data = schema.model_validate(data) + return cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=data)) + + __all__ = [ "Resolve", "Elicit", diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 6aab3c777..50d28f574 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -4,7 +4,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from mcp_types import Icon, ToolAnnotations +from mcp_types import Icon, InputRequiredResult, ToolAnnotations from pydantic import BaseModel, Field from mcp.server.mcpserver.exceptions import ToolError @@ -135,9 +135,12 @@ async def run( pre_validated: dict[str, Any] | None = None if self.resolved_params: pre_validated = self.fn_metadata.validate_arguments(arguments) - pass_directly |= await resolve_arguments( - self.resolved_params, self.resolver_plans, pre_validated, context - ) + resolved = await resolve_arguments(self.resolved_params, self.resolver_plans, pre_validated, context) + if isinstance(resolved, InputRequiredResult): + # A resolver still needs client input (>= 2026-07-28): surface the + # batched questions instead of running the tool body this round. + return self.fn_metadata.convert_result(resolved) if convert_result else resolved + pass_directly |= resolved result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 3970893d2..61ec290d0 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -1,9 +1,17 @@ """Tests for resolver dependency injection (MRTR) on MCPServer tools.""" -from typing import Annotated, Literal +from collections.abc import Callable +from typing import Annotated, Literal, cast import pytest -from mcp_types import ElicitRequestParams, ElicitResult, TextContent +from mcp_types import ( + CallToolResult, + ElicitRequestParams, + ElicitResult, + InputRequiredResult, + InputResponses, + TextContent, +) from pydantic import BaseModel, Field from mcp import Client @@ -19,7 +27,15 @@ Resolve, ) from mcp.server.mcpserver.exceptions import InvalidSignature -from mcp.server.mcpserver.resolve import _resolver_key, find_resolved_parameters +from mcp.server.mcpserver.resolve import ( + _decode_state, + _elicit_return_schema, + _encode_state, + _outcome_from_state, + _resolver_key, + find_resolved_parameters, + uses_input_required, +) from mcp.server.mcpserver.tools.base import Tool @@ -53,6 +69,36 @@ async def _text(client: Client, tool: str, args: dict[str, object]) -> str: return result.content[0].text +async def _drive_mrtr( + client: Client, + tool: str, + args: dict[str, object], + answer: Callable[[str, ElicitRequestParams], ElicitResult], + max_rounds: int = 10, +) -> CallToolResult: + """Drive the 2026-07-28 `input_required` loop to completion. + + Re-invokes `tools/call` with `input_responses`/`request_state` until the + server returns a final `CallToolResult`, fulfilling each pending request via + `answer(key, request_params)`. + """ + responses: InputResponses | None = None + state: str | None = None + for _ in range(max_rounds): + result = await client.call_tool( + tool, args, input_responses=responses, request_state=state, allow_input_required=True + ) + if isinstance(result, CallToolResult): + return result + assert isinstance(result, InputRequiredResult) + assert result.input_requests is not None + responses = { + key: answer(key, cast(ElicitRequestParams, req.params)) for key, req in result.input_requests.items() + } + state = result.request_state + raise AssertionError("input_required loop did not converge") # pragma: no cover + + @pytest.mark.anyio async def test_resolver_returns_value_directly_without_eliciting(): mcp = MCPServer(name="Direct") @@ -543,3 +589,221 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: assert await _text(client, "delete_folder", {"path": "/docs"}) == expected assert ("/docs" in fs) is (expected != "deleted /docs") + + +@pytest.mark.anyio +async def test_input_required_first_round_returns_the_question(): + mcp, fs = _delete_folder_server() + fs["/docs"] = ["a.txt", "b.txt"] + + async with Client(mcp) as client: # mode="auto" negotiates 2026-07-28 + assert client.session.protocol_version == "2026-07-28" + result = await client.call_tool("delete_folder", {"path": "/docs"}, allow_input_required=True) + assert isinstance(result, InputRequiredResult) + assert result.input_requests is not None + (request,) = result.input_requests.values() + assert request.method == "elicitation/create" + assert "/docs has 2 file(s)" in request.params.message + assert result.request_state is not None + assert "/docs" in fs # nothing deleted before the answer arrives + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("action", "content", "expected"), + [ + ("accept", {"ok": True}, "deleted /docs"), + ("accept", {"ok": False}, "kept the folder"), + ("decline", None, "declined: folder not deleted"), + ("cancel", None, "cancelled: folder not deleted"), + ], +) +async def test_input_required_loop_handles_every_outcome( + action: Literal["accept", "decline", "cancel"], + content: dict[str, str | int | float | bool | list[str] | None] | None, + expected: str, +): + mcp, fs = _delete_folder_server() + fs["/docs"] = ["a.txt", "b.txt"] + + def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + assert "/docs has 2 file(s)" in params.message + return ElicitResult(action=action, content=content) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "delete_folder", {"path": "/docs"}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == expected + assert ("/docs" in fs) is (expected != "deleted /docs") + + +@pytest.mark.anyio +async def test_input_required_empty_folder_completes_in_one_round(): + mcp, fs = _delete_folder_server() + fs["/empty"] = [] + + def never(key: str, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit for an empty folder") + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "delete_folder", {"path": "/empty"}, never) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "deleted /empty" + assert "/empty" not in fs + + +@pytest.mark.anyio +async def test_input_required_resolver_asks_and_consumes_then_never_reruns(): + mcp = MCPServer(name="ExactlyOnceMRTR") + counts = {"login": 0, "confirm": 0} + + async def login(ctx: Context) -> Login | Elicit[Login]: + counts["login"] += 1 + return Elicit("Username?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + counts["confirm"] += 1 + return Elicit(f"As {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + if "Username" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "act", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + # An eliciting resolver runs twice - once to ask, once to consume the answer - + # then its outcome is carried in `request_state` and it never runs again. `login` + # asks in round 1 and is consumed in round 2; `confirm` (which depends on + # `login`) only forms its question once `login` is known, so it asks in round 2 + # and is consumed in round 3. Neither re-runs beyond consuming its own answer. + assert counts == {"login": 2, "confirm": 2} + + +@pytest.mark.anyio +async def test_input_required_batches_independent_elicits_in_one_round(): + mcp = MCPServer(name="BatchedMRTR") + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + async def ask_confirm(ctx: Context) -> Elicit[Confirm]: + return Elicit("Confirm?", Confirm) + + @mcp.tool() + async def both( + name: Annotated[Login, Resolve(ask_name)], + confirm: Annotated[Confirm, Resolve(ask_confirm)], + ) -> str: + return f"{name.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + if "Name" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + # Both independent resolvers are asked together in the first round. + first = await client.call_tool("both", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 + + result = await _drive_mrtr(client, "both", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + +def test_uses_input_required_version_gate(): + assert uses_input_required("2026-07-28") is True + assert uses_input_required("2025-11-25") is False + assert uses_input_required(None) is False + + +@pytest.mark.parametrize( + "request_state", + [ + None, + "", + "not json", + '{"v": 99, "outcomes": {}}', # wrong version + '{"v": 1}', # missing outcomes + '{"v": 1, "outcomes": []}', # outcomes not a dict + "[1, 2, 3]", # not an object + ], +) +def test_decode_state_tolerates_malformed_request_state(request_state: str | None): + assert _decode_state(request_state) == {} + + +def test_state_round_trips_accept_decline_cancel(): + outcomes: dict[str, ElicitationResult[BaseModel]] = { + "a": AcceptedElicitation(data=Login(username="octocat")), + "b": DeclinedElicitation(), + "c": CancelledElicitation(), + "d": AcceptedElicitation.model_construct(data="raw-token"), # non-model value + } + decoded = _decode_state(_encode_state(outcomes)) + + accepted = _outcome_from_state(decoded["a"], Login) + assert isinstance(accepted, AcceptedElicitation) and accepted.data == Login(username="octocat") + assert isinstance(_outcome_from_state(decoded["b"], None), DeclinedElicitation) + assert isinstance(_outcome_from_state(decoded["c"], None), CancelledElicitation) + raw = _outcome_from_state(decoded["d"], None) + assert isinstance(raw, AcceptedElicitation) and raw.data == "raw-token" + + +def test_elicit_return_schema_extraction(): + async def with_elicit(ctx: Context) -> Login | Elicit[Login]: + return Elicit("?", Login) # pragma: no cover + + async def without_elicit(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + assert _elicit_return_schema(Login | Elicit[Login]) is Login + assert _elicit_return_schema(Login) is None + assert _elicit_return_schema(None) is None + + +@pytest.mark.anyio +async def test_non_elicitation_response_raises(): + from mcp_types import CreateMessageResult, TextContent + + mcp = MCPServer(name="WrongResponse") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + @mcp.tool() + async def tool(name: Annotated[Login, Resolve(ask)]) -> str: + return name.username # pragma: no cover + + async with Client(mcp) as client: + r1 = await client.call_tool("tool", {}, allow_input_required=True) + assert isinstance(r1, InputRequiredResult) + assert r1.input_requests is not None + (key,) = r1.input_requests + # Answer with a sampling result instead of an elicitation result. + r2 = await client.call_tool( + "tool", + {}, + input_responses={ + key: CreateMessageResult(role="assistant", content=TextContent(type="text", text="x"), model="m") + }, + request_state=r1.request_state, + allow_input_required=True, + ) + assert isinstance(r2, CallToolResult) + assert r2.is_error + assert isinstance(r2.content[0], TextContent) + assert "non-elicitation response" in r2.content[0].text From b5d8d1e9d1e858c27a1a9645ccff6cf5c632082c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 11:50:14 +0200 Subject: [PATCH 2/5] Remove casts from the input_required resolver path Replace every cast() with a checked or properly-typed alternative: model the request_state payload with pydantic (_State/_StateEntry) so the untrusted JSON is validated instead of cast; type _Resolution.pending as InputRequests so an ElicitRequest fits without a cast; add a _is_elicit TypeGuard and an _accepted helper that carry the right types; and narrow req.params via isinstance in the test helper. No behavior change. --- src/mcp/server/mcpserver/resolve.py | 95 +++++++++++++++----------- tests/server/mcpserver/test_resolve.py | 18 ++--- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 2bc6e15d4..b9a378c17 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -26,10 +26,9 @@ from __future__ import annotations import inspect -import json import typing from collections.abc import Callable, Hashable, Mapping -from typing import Annotated, Any, Generic, cast, get_args, get_origin +from typing import Annotated, Any, Generic, Literal, TypeGuard, get_args, get_origin import anyio.to_thread from mcp_types import ( @@ -41,7 +40,7 @@ InputResponses, ) from mcp_types.version import LATEST_MODERN_VERSION, is_version_at_least -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar from mcp.server.elicitation import ( @@ -298,7 +297,7 @@ def __init__( self.answers: InputResponses = context.input_responses or {} if input_required else {} self.state = _decode_state(context.request_state) if input_required else {} self.cache: dict[str, ElicitationResult[Any]] = {} - self.pending: dict[str, ElicitRequest] = {} + self.pending: InputRequests = {} def _state_key(fn: Callable[..., Any]) -> str: @@ -342,10 +341,7 @@ async def resolve_arguments( injected[name] = outcome if wants_union else _unwrap(outcome, name) if res.pending: - return InputRequiredResult( - input_requests=cast("InputRequests", res.pending), - request_state=_encode_state(res.cache), - ) + return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.cache)) return injected @@ -379,23 +375,24 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul dep_outcome = await _resolve(param_plan.resolve.fn, res) kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) + result: Any if plan.is_async: result = await fn(**kwargs) else: result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) - if isinstance(result, Elicit): - outcome = await _elicit(cast("Elicit[BaseModel]", result), key, res) + if _is_elicit(result): + outcome = await _elicit(result, key, res) else: - # A resolver may return any type (not just `BaseModel`); `model_construct` - # wraps it as an accepted result without validating against the schema bound. - outcome = cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=result)) + # A resolver may return any type (not just `BaseModel`), so accept it as the + # outcome without validating against the schema bound. + outcome = _accepted(result) res.cache[key] = outcome return outcome -async def _elicit(elicit: Elicit[BaseModel], key: str, res: _Resolution) -> ElicitationResult[Any]: +async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> ElicitationResult[Any]: """Turn a resolver's `Elicit` into an outcome via the negotiated transport.""" if not res.input_required: return await res.context.elicit(elicit.message, elicit.schema) @@ -419,6 +416,20 @@ def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") +def _is_elicit(value: Any) -> TypeGuard[Elicit[Any]]: + """Runtime narrow of a resolver's return value to a (parameter-erased) `Elicit`.""" + return isinstance(value, Elicit) + + +def _accepted(data: Any) -> AcceptedElicitation[Any]: + """Wrap a resolved value as an accepted outcome without schema validation. + + A resolver may return any type (the schema bound only constrains `Elicit[T]`), + and a value restored from `request_state` is already validated. + """ + return AcceptedElicitation[Any].model_construct(data=data) + + def uses_input_required(protocol_version: str | None) -> bool: """True when this request must elicit via `InputRequiredResult` (>= 2026-07-28). @@ -434,50 +445,56 @@ def _elicit_request(elicit: Elicit[Any]) -> ElicitRequest: return ElicitRequest(params=ElicitRequestFormParams(message=elicit.message, requested_schema=json_schema)) -def _decode_state(request_state: str | None) -> dict[str, dict[str, Any]]: +class _StateEntry(BaseModel): + """One resolver's recorded outcome inside `request_state`.""" + + action: Literal["accept", "decline", "cancel"] + data: Any = None + + +class _State(BaseModel): + """The decoded `request_state`: resolver outcomes from earlier rounds.""" + + v: int + outcomes: dict[str, _StateEntry] = {} + + +def _decode_state(request_state: str | None) -> dict[str, _StateEntry]: """Decode the per-call resolution progress from `request_state`. - `request_state` is client-trusted (integrity sealing is a follow-up); decode - defensively and treat anything malformed as "no progress yet". + `request_state` is client-trusted (integrity sealing is a follow-up); validate + it through `_State` and treat anything malformed as "no progress yet". """ if not request_state: return {} try: - decoded: Any = json.loads(request_state) - except json.JSONDecodeError: - return {} - if not isinstance(decoded, dict): - return {} - payload = cast("dict[str, Any]", decoded) - if payload.get("v") != _STATE_VERSION: + state = _State.model_validate_json(request_state) + except ValidationError: return {} - outcomes = payload.get("outcomes") - return cast("dict[str, dict[str, Any]]", outcomes) if isinstance(outcomes, dict) else {} + return state.outcomes if state.v == _STATE_VERSION else {} def _encode_state(outcomes: Mapping[str, ElicitationResult[Any]]) -> str: """Encode resolved outcomes (keyed by resolver path) for the next round.""" - encoded: dict[str, dict[str, Any]] = {} + entries: dict[str, _StateEntry] = {} for path, outcome in outcomes.items(): - entry: dict[str, Any] = {"action": outcome.action} - if isinstance(outcome, AcceptedElicitation): - data = outcome.data - entry["data"] = data.model_dump(mode="json") if isinstance(data, BaseModel) else data - encoded[path] = entry - return json.dumps({"v": _STATE_VERSION, "outcomes": encoded}) + data = outcome.data if isinstance(outcome, AcceptedElicitation) else None + if isinstance(data, BaseModel): + data = data.model_dump(mode="json") + entries[path] = _StateEntry(action=outcome.action, data=data) + return _State(v=_STATE_VERSION, outcomes=entries).model_dump_json() -def _outcome_from_state(entry: Mapping[str, Any], schema: type[BaseModel] | None) -> ElicitationResult[Any]: +def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel] | None) -> ElicitationResult[Any]: """Rebuild an `ElicitationResult` from a decoded `request_state` entry.""" - action = entry.get("action") - if action == "decline": + if entry.action == "decline": return DeclinedElicitation() - if action == "cancel": + if entry.action == "cancel": return CancelledElicitation() - data = entry.get("data") + data = entry.data if schema is not None and isinstance(data, dict): data = schema.model_validate(data) - return cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=data)) + return _accepted(data) __all__ = [ diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 61ec290d0..5ef3e7177 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -1,11 +1,12 @@ """Tests for resolver dependency injection (MRTR) on MCPServer tools.""" from collections.abc import Callable -from typing import Annotated, Literal, cast +from typing import Annotated, Literal import pytest from mcp_types import ( CallToolResult, + ElicitRequestFormParams, ElicitRequestParams, ElicitResult, InputRequiredResult, @@ -73,7 +74,7 @@ async def _drive_mrtr( client: Client, tool: str, args: dict[str, object], - answer: Callable[[str, ElicitRequestParams], ElicitResult], + answer: Callable[[str, ElicitRequestFormParams], ElicitResult], max_rounds: int = 10, ) -> CallToolResult: """Drive the 2026-07-28 `input_required` loop to completion. @@ -92,9 +93,10 @@ async def _drive_mrtr( return result assert isinstance(result, InputRequiredResult) assert result.input_requests is not None - responses = { - key: answer(key, cast(ElicitRequestParams, req.params)) for key, req in result.input_requests.items() - } + responses = {} + for key, req in result.input_requests.items(): + assert isinstance(req.params, ElicitRequestFormParams) + responses[key] = answer(key, req.params) state = result.request_state raise AssertionError("input_required loop did not converge") # pragma: no cover @@ -626,7 +628,7 @@ async def test_input_required_loop_handles_every_outcome( mcp, fs = _delete_folder_server() fs["/docs"] = ["a.txt", "b.txt"] - def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: assert "/docs has 2 file(s)" in params.message return ElicitResult(action=action, content=content) @@ -672,7 +674,7 @@ async def act( ) -> str: return f"{login.username}:{confirm.ok}" - def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: if "Username" in params.message: return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) @@ -707,7 +709,7 @@ async def both( ) -> str: return f"{name.username}:{confirm.ok}" - def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: if "Name" in params.message: return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) From db12e6c5e3f65ab2a95d1573098c576c81cd3209 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 13:28:55 +0200 Subject: [PATCH 3/5] Fix resolver MRTR edge cases from review - In-call cache keyed by _resolver_key (instance-distinct) again; _state_key adds id(__self__) for the wire key, so two instances of one bound method no longer collide and silently share an outcome. - resolve_arguments reads ctx.protocol_version (new Context property, None outside a request) instead of dereferencing request_context, so direct MCPServer.call_tool() works for tools whose resolvers never elicit. - request_state persists only elicited outcomes (always validated models); a resolver that resolves without eliciting is pure and re-runs each round. Fixes the json.dumps crash on non-serializable returns (datetime/set/...) and the dict-degradation of restored values. - _elicit_return_schema handles a bare Elicit[T] return (not only unions). - _INPUT_REQUIRED_VERSION pinned to '2026-07-28' instead of LATEST_MODERN_VERSION. - accept with no content raises ToolError instead of silently reporting cancel. - Independent nested resolver deps batch into one round (catch _Pending per dep). - Test cleanup: drop dead helpers, hoist CreateMessageResult import. Add regression tests for each; document the narrowed elicited-only persistence. --- docs/migration.md | 2 +- src/mcp/server/mcpserver/context.py | 5 + src/mcp/server/mcpserver/resolve.py | 101 +++++++++----- tests/server/mcpserver/test_resolve.py | 186 +++++++++++++++++++++++-- 4 files changed, 250 insertions(+), 44 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index fb3784681..55bb37fdf 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1515,7 +1515,7 @@ async def delete_folder( The `confirm_delete` resolver reads the tool's own `path` argument by name, lists the folder, and only elicits when the folder is non-empty - an empty folder resolves to `Confirm(ok=True)` with no round-trip to the client. Because `delete_folder` annotates the result union, it handles every outcome: the user accepting and confirming, accepting but declining to delete (`ok=False`), declining the elicitation, or cancelling it. -The framework drives elicitation over whichever transport the negotiated protocol provides, so the resolver and tool code above is unchanged either way. At `2026-07-28` and later it returns an `InputRequiredResult` carrying the questions and resumes when the client retries `call_tool(..., input_responses=..., request_state=...)` (independent resolvers are batched into one round; a resolver that depends on another's answer is asked in a later round). At `2025-11-25` and earlier it issues a synchronous `elicitation/create` request mid-call. Resolved outcomes are carried in `request_state` across rounds so each resolver resolves once per call. +The framework drives elicitation over whichever transport the negotiated protocol provides, so the resolver and tool code above is unchanged either way. At `2026-07-28` and later it returns an `InputRequiredResult` carrying the questions and resumes when the client retries `call_tool(..., input_responses=..., request_state=...)` (independent resolvers are batched into one round; a resolver that depends on another's answer is asked in a later round). At `2025-11-25` and earlier it issues a synchronous `elicitation/create` request mid-call. Elicited answers are carried in `request_state` across rounds, so each question is asked once; a resolver that resolves without eliciting is pure and may re-run each round. Resolved parameters are omitted from the tool's input schema, so the client never supplies them. Resolver parameters that cannot be classified, and cyclic resolver dependencies, raise at registration time. diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 28daef1cd..c40700160 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -238,6 +238,11 @@ def request_id(self) -> str: """Get the unique ID for this request.""" return str(self.request_context.request_id) + @property + def protocol_version(self) -> str | None: + """The negotiated protocol version, or `None` outside of an active request.""" + return self._request_context.protocol_version if self._request_context is not None else None + @property def input_responses(self) -> InputResponses | None: """Client responses to a prior `InputRequiredResult.input_requests`. diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index b9a378c17..1986adf4a 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -12,8 +12,9 @@ and resumes when the client retries with `input_responses`/`request_state` (independent resolvers are asked in one round; a resolver depending on another's answer is asked in a later round). At <= 2025-11-25 it issues a synchronous -`elicitation/create` request mid-call. Resolved outcomes are carried in -`request_state` across rounds, so each resolver resolves once per logical call. +`elicitation/create` request mid-call. Only *elicited* outcomes are carried in +`request_state` across rounds (so the user is asked each question once); a +resolver that returns a value without eliciting is pure and may re-run each round. Whether the consumer receives the unwrapped model or the full `ElicitationResult` union is decided by the consumer's annotation: @@ -26,6 +27,7 @@ from __future__ import annotations import inspect +import types import typing from collections.abc import Callable, Hashable, Mapping from typing import Annotated, Any, Generic, Literal, TypeGuard, get_args, get_origin @@ -39,7 +41,7 @@ InputRequiredResult, InputResponses, ) -from mcp_types.version import LATEST_MODERN_VERSION, is_version_at_least +from mcp_types.version import is_version_at_least from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar @@ -61,7 +63,8 @@ # First protocol revision whose `tools/call` carries elicitation inside # `InputRequiredResult` rather than as a standalone server-to-client request. -_INPUT_REQUIRED_VERSION = LATEST_MODERN_VERSION # "2026-07-28" +# Pinned (not `LATEST_MODERN_VERSION`, which moves when newer revisions are added). +_INPUT_REQUIRED_VERSION = "2026-07-28" _STATE_VERSION = 1 @@ -158,10 +161,12 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. - Lets an outcome restored from `request_state` (a plain dict) be re-validated - into its model so dependent resolvers and tools receive a typed value. + Handles a bare `-> Elicit[T]` and a `-> T | Elicit[T]` union. Lets an elicited + outcome restored from `request_state` (a plain dict) be re-validated into its + model so dependent resolvers and tools receive a typed value. """ - candidates = get_args(return_annotation) if get_origin(return_annotation) is not None else (return_annotation,) + # A bare `Elicit[T]` is itself a candidate; a union contributes its members. + candidates = get_args(return_annotation) if _is_union(return_annotation) else (return_annotation,) for candidate in candidates: if get_origin(candidate) is Elicit: schema = get_args(candidate)[0] @@ -170,6 +175,10 @@ def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: return None +def _is_union(annotation: Any) -> bool: + return get_origin(annotation) in (typing.Union, types.UnionType) + + def _wants_union(type_arg: Any) -> bool: """True when `type_arg` is an `ElicitationResult` member (or a union of them). @@ -296,19 +305,26 @@ def __init__( self.input_required = input_required self.answers: InputResponses = context.input_responses or {} if input_required else {} self.state = _decode_state(context.request_state) if input_required else {} - self.cache: dict[str, ElicitationResult[Any]] = {} + # In-call dedup keyed by resolver identity (distinguishes two instances of + # the same bound method); `elicited` holds only outcomes that came from an + # elicitation, keyed by their wire key - these are what `request_state` + # persists, since pure resolvers are cheap to re-run each round. + self.cache: dict[Hashable, ElicitationResult[Any]] = {} + self.elicited: dict[str, ElicitationResult[Any]] = {} self.pending: InputRequests = {} def _state_key(fn: Callable[..., Any]) -> str: - """Process-stable wire key for a resolver. + """Process-stable wire key for a resolver's elicitation. - `id`-based keys aren't stable across `input_required` rounds (a retry may land - on a different worker), so memoize and key `input_requests`/`request_state` by - the resolver's `module:qualname`. Two consumers of the same resolver therefore - share one cache entry, one question, and one stored outcome. + `id(fn)` isn't stable across `input_required` rounds, so key `input_requests` / + `request_state` by `module:qualname`. Bound methods add their `__self__` id so + two instances of the same method get distinct questions and stored outcomes + (the registered `Resolve(...)` holds the instance for the call's lifetime). """ - return f"{getattr(fn, '__module__', '')}:{getattr(fn, '__qualname__', fn)}" + base = f"{getattr(fn, '__module__', '')}:{getattr(fn, '__qualname__', fn)!s}" + bound_self = getattr(fn, "__self__", None) + return f"{base}#{id(bound_self)}" if bound_self is not None else base async def resolve_arguments( @@ -324,14 +340,19 @@ async def resolve_arguments( negotiated protocol is >= 2026-07-28), returns an `InputRequiredResult` carrying the batched questions instead; the tool body is not run. - Each resolver runs at most once per logical call - across multiple - `input_required` rounds, resolved outcomes are carried in `request_state`. + An eliciting resolver asks its question once - its answer is carried in + `request_state` across rounds - while a resolver that resolves without + eliciting is pure and may re-run on each round. Raises: ToolError: If an elicited value is declined or cancelled and the consumer asked for the unwrapped model (rather than the result union). """ - res = _Resolution(plans, tool_args, context, uses_input_required(context.request_context.protocol_version)) + # `ctx.protocol_version` is `None` outside an active request: `MCPServer.call_tool()` + # called directly builds such a `Context`, and a tool whose resolvers never elicit + # must still work there. A missing version means the synchronous (non-input_required) + # transport, which never reaches a server-to-client request anyway. + res = _Resolution(plans, tool_args, context, uses_input_required(context.protocol_version)) injected: dict[str, Any] = {} for name, (marker, wants_union) in resolved_params.items(): try: @@ -341,30 +362,32 @@ async def resolve_arguments( injected[name] = outcome if wants_union else _unwrap(outcome, name) if res.pending: - return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.cache)) + return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.elicited)) return injected async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResult[Any]: - """Resolve one resolver, memoized by its process-stable state key. + """Resolve one resolver, deduped within the call by its resolver identity. Raises `_Pending` when the resolver (or one of its dependencies) needs client input that has not arrived yet. """ - key = _state_key(fn) - if key in res.cache: - return res.cache[key] - if key in res.pending: + cache_key = _resolver_key(fn) + if cache_key in res.cache: + return res.cache[cache_key] + + plan = res.plans[cache_key] + wire_key = _state_key(fn) + if wire_key in res.pending: # Already asked this round by another consumer; don't run the resolver again. raise _Pending - - plan = res.plans[_resolver_key(fn)] - if key in res.state: - outcome = _outcome_from_state(res.state[key], plan.elicit_schema) - res.cache[key] = outcome + if wire_key in res.state: + outcome = _outcome_from_state(res.state[wire_key], plan.elicit_schema) + res.cache[cache_key] = outcome return outcome kwargs: dict[str, Any] = {} + dep_pending = False for param_name, param_plan in plan.params.items(): if param_plan.kind == "context": kwargs[param_name] = res.context @@ -372,8 +395,16 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul kwargs[param_name] = res.tool_args[param_name] else: assert param_plan.resolve is not None - dep_outcome = await _resolve(param_plan.resolve.fn, res) + try: + # Visit every dependency so independent ones that need input are all + # collected into `res.pending` and batched into a single round. + dep_outcome = await _resolve(param_plan.resolve.fn, res) + except _Pending: + dep_pending = True + continue kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) + if dep_pending: + raise _Pending result: Any if plan.is_async: @@ -382,13 +413,15 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) if _is_elicit(result): - outcome = await _elicit(result, key, res) + outcome = await _elicit(result, wire_key, res) + res.elicited[wire_key] = outcome else: # A resolver may return any type (not just `BaseModel`), so accept it as the - # outcome without validating against the schema bound. + # outcome without validating against the schema bound. Plain outcomes are not + # persisted in `request_state`; the resolver re-runs next round instead. outcome = _accepted(result) - res.cache[key] = outcome + res.cache[cache_key] = outcome return outcome @@ -403,7 +436,9 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio raise _Pending if not isinstance(answer, ElicitResult): raise ToolError(f"Resolver {key!r} received a non-elicitation response") - if answer.action == "accept" and answer.content is not None: + if answer.action == "accept": + if answer.content is None: + raise ToolError(f"Resolver {key!r} received an accepted elicitation with no content") return AcceptedElicitation(data=elicit.schema.model_validate(answer.content)) if answer.action == "decline": return DeclinedElicitation() diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 5ef3e7177..ef0c1c92f 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -6,6 +6,7 @@ import pytest from mcp_types import ( CallToolResult, + CreateMessageResult, ElicitRequestFormParams, ElicitRequestParams, ElicitResult, @@ -766,21 +767,14 @@ def test_state_round_trips_accept_decline_cancel(): def test_elicit_return_schema_extraction(): - async def with_elicit(ctx: Context) -> Login | Elicit[Login]: - return Elicit("?", Login) # pragma: no cover - - async def without_elicit(ctx: Context) -> Login: - return Login(username="x") # pragma: no cover - - assert _elicit_return_schema(Login | Elicit[Login]) is Login - assert _elicit_return_schema(Login) is None + assert _elicit_return_schema(Elicit[Login]) is Login # bare Elicit[T] + assert _elicit_return_schema(Login | Elicit[Login]) is Login # union arm + assert _elicit_return_schema(Login) is None # no Elicit arm assert _elicit_return_schema(None) is None @pytest.mark.anyio async def test_non_elicitation_response_raises(): - from mcp_types import CreateMessageResult, TextContent - mcp = MCPServer(name="WrongResponse") async def ask(ctx: Context) -> Elicit[Login]: @@ -809,3 +803,175 @@ async def tool(name: Annotated[Login, Resolve(ask)]) -> str: assert r2.is_error assert isinstance(r2.content[0], TextContent) assert "non-elicitation response" in r2.content[0].text + + +@pytest.mark.anyio +async def test_direct_call_tool_with_non_eliciting_resolver(): + # `MCPServer.call_tool()` called directly builds a Context with no request, so + # `ctx.protocol_version` is None. A tool whose resolvers never elicit must still + # work there (regression: it used to raise "Context is not available"). + mcp = MCPServer(name="Direct") + + async def whoami(ctx: Context) -> Login: + return Login(username="direct") + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(whoami)]) -> str: + return login.username + + result = await mcp.call_tool("tool", {}, Context(mcp_server=mcp)) + assert isinstance(result, CallToolResult) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "direct" + + +@pytest.mark.anyio +async def test_two_instances_of_one_method_do_not_collide(): + mcp = MCPServer(name="Instances") + + class Service: + def __init__(self, name: str) -> None: + self.name = name + + async def who(self, ctx: Context) -> Login: + return Login(username=self.name) + + alice, bob = Service("alice"), Service("bob") + + @mcp.tool() + async def both( + a: Annotated[Login, Resolve(alice.who)], + b: Annotated[Login, Resolve(bob.who)], + ) -> str: + return f"{a.username},{b.username}" + + result = await mcp.call_tool("both", {}, Context(mcp_server=mcp)) + assert isinstance(result, CallToolResult) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "alice,bob" + + +@pytest.mark.anyio +async def test_non_serializable_sibling_resolver_does_not_break_rounds(): + from datetime import datetime + + mcp = MCPServer(name="NonSerializable") + + async def clock(ctx: Context) -> datetime: + return datetime(2026, 1, 1) + + async def ask(ctx: Context) -> Elicit[Confirm]: + return Elicit("ok?", Confirm) + + @mcp.tool() + async def act( + when: Annotated[datetime, Resolve(clock)], + confirm: Annotated[Confirm, Resolve(ask)], + ) -> str: + return f"{when.year}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "act", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "2026:True" + + +@pytest.mark.anyio +async def test_bare_elicit_dependency_restored_as_model(): + # A `-> Elicit[Login]` (bare, no union) resolver feeds a dependent resolver. After + # the round-trip the dependency must come back as a Login model, not a raw dict. + mcp = MCPServer(name="BareElicitDep") + + async def login(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"as {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "user" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "as octocat?" in params.message # proves login was a real model + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "act", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_accept_with_no_content_is_an_error_not_a_cancel(): + mcp = MCPServer(name="AcceptNoContent") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username # pragma: no cover + + async with Client(mcp) as client: + r1 = await client.call_tool("tool", {}, allow_input_required=True) + assert isinstance(r1, InputRequiredResult) + assert r1.input_requests is not None + (key,) = r1.input_requests + r2 = await client.call_tool( + "tool", + {}, + input_responses={key: ElicitResult(action="accept", content=None)}, + request_state=r1.request_state, + allow_input_required=True, + ) + assert isinstance(r2, CallToolResult) + assert r2.is_error + assert isinstance(r2.content[0], TextContent) + assert "no content" in r2.content[0].text + + +@pytest.mark.anyio +async def test_independent_nested_deps_batch_into_one_round(): + mcp = MCPServer(name="NestedBatch") + + async def ask_a(ctx: Context) -> Elicit[Login]: + return Elicit("A name?", Login) + + async def ask_b(ctx: Context) -> Elicit[Confirm]: + return Elicit("B confirm?", Confirm) + + # `combine` depends on two independent eliciting resolvers; both must be asked + # in the same round, not serialized across two InputRequiredResult rounds. + async def combine( + a: Annotated[Login, Resolve(ask_a)], + b: Annotated[Confirm, Resolve(ask_b)], + ) -> Login: + return Login(username=f"{a.username}:{b.ok}") + + @mcp.tool() + async def tool(combined: Annotated[Login, Resolve(combine)]) -> str: + return combined.username + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "name" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + first = await client.call_tool("tool", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 # batched, not serialized + + result = await _drive_mrtr(client, "tool", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" From ecb54b3e5554a052fb98c3a8fe654bb3e6dc6946 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 13:38:26 +0200 Subject: [PATCH 4/5] Fix two MRTR state bugs found by Codex review - Carry restored answers forward: an elicited outcome restored from request_state is now re-added to res.elicited, so in a 4+-round dependency chain an early answer is not dropped from request_state and re-asked on a later round. - Collision-free wire keys: assign each resolver a deterministic wire key at registration (module:qualname, disambiguated with #N when bases collide), so two distinct closures from one factory get separate questions/outcomes instead of sharing one. _state_key is now only the base-key source at registration. Add regression tests: a deep chain asserting an early answer is asked once, and factory closures asserting distinct wire keys and correct injected values. --- src/mcp/server/mcpserver/resolve.py | 23 +++++++- tests/server/mcpserver/test_resolve.py | 77 ++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 1986adf4a..383963d68 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -109,6 +109,7 @@ def __init__( params: dict[str, _ParamPlan], is_async: bool, elicit_schema: type[BaseModel] | None, + wire_key: str, ) -> None: self.fn = fn self.params = params @@ -116,6 +117,10 @@ def __init__( # The `T` from the resolver's `Elicit[T]` return arm, if annotated. Used to # re-validate an outcome restored from `request_state` into a model. self.elicit_schema = elicit_schema + # Deterministic, collision-free key for this resolver's elicitation on the + # wire (`input_requests`/`request_state`). Assigned at registration so it is + # stable across rounds even when `module:qualname` collides (closures). + self.wire_key = wire_key def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: @@ -226,6 +231,9 @@ def build_resolver_plans( or a tool argument by name). """ plans: dict[Hashable, _ResolverPlan] = {} + # Count how many distinct resolvers share each `module:qualname` base so closures + # from one factory get distinct, deterministic wire keys (`base`, `base#1`, ...). + base_counts: dict[str, int] = {} def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: key = _resolver_key(fn) @@ -234,6 +242,11 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: if key in plans: return + base = _state_key(fn) + seen = base_counts.get(base, 0) + base_counts[base] = seen + 1 + wire_key = base if seen == 0 else f"{base}#{seen}" + hints = _type_hints(fn) sig = inspect.signature(fn) params: dict[str, _ParamPlan] = {} @@ -256,7 +269,9 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" ) - plans[key] = _ResolverPlan(fn, params, is_async_callable(fn), _elicit_return_schema(hints.get("return"))) + plans[key] = _ResolverPlan( + fn, params, is_async_callable(fn), _elicit_return_schema(hints.get("return")), wire_key + ) for dep in nested: analyze(dep, stack + (key,)) @@ -377,13 +392,17 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul return res.cache[cache_key] plan = res.plans[cache_key] - wire_key = _state_key(fn) + wire_key = plan.wire_key if wire_key in res.pending: # Already asked this round by another consumer; don't run the resolver again. raise _Pending if wire_key in res.state: outcome = _outcome_from_state(res.state[wire_key], plan.elicit_schema) res.cache[cache_key] = outcome + # Carry the restored answer forward: if a later resolver is still pending, + # the next round's `request_state` is built from `res.elicited`, so an + # earlier answer must stay there or it would be dropped and re-asked. + res.elicited[wire_key] = outcome return outcome kwargs: dict[str, Any] = {} diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index ef0c1c92f..2f9742cad 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -975,3 +975,80 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: result = await _drive_mrtr(client, "tool", {}, answer) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_deep_chain_keeps_early_answers_across_rounds(): + # A 4-round dependency chain where an early answer (A) must survive in + # request_state while later resolvers are asked. It must be asked exactly once. + mcp = MCPServer(name="DeepChain") + + async def ra(ctx: Context) -> Elicit[Login]: + return Elicit("A name?", Login) + + async def rb(a: Annotated[Login, Resolve(ra)]) -> Elicit[Confirm]: + return Elicit("B?", Confirm) + + async def rc(b: Annotated[Confirm, Resolve(rb)]) -> Elicit[Confirm]: + return Elicit("C?", Confirm) + + async def rd(c: Annotated[Confirm, Resolve(rc)]) -> Elicit[Confirm]: + return Elicit("D?", Confirm) + + # Depends on `ra` directly AND on `rd` (which transitively needs ra->rb->rc). + @mcp.tool() + async def tool( + a: Annotated[Login, Resolve(ra)], + d: Annotated[Confirm, Resolve(rd)], + ) -> str: + return f"{a.username}:{d.ok}" + + a_asks = 0 + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + nonlocal a_asks + if "name" in params.message: + a_asks += 1 + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "tool", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + assert a_asks == 1 # ra's answer survived in request_state; never re-asked + + +@pytest.mark.anyio +async def test_factory_closures_get_distinct_wire_keys(): + # Two resolvers from one factory share module:qualname; they must still get + # distinct questions and their own values (regression: they collided on the wire). + mcp = MCPServer(name="FactoryClosures") + + def make(label: str): + async def resolver(ctx: Context) -> Elicit[Login]: + return Elicit(f"{label}?", Login) + + return resolver + + ask_a, ask_b = make("A"), make("B") + + @mcp.tool() + async def tool( + a: Annotated[Login, Resolve(ask_a)], + b: Annotated[Login, Resolve(ask_b)], + ) -> str: + return f"{a.username},{b.username}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + return ElicitResult(action="accept", content={"username": params.message[0]}) + + async with Client(mcp) as client: + first = await client.call_tool("tool", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 # distinct keys, not collapsed to one + + result = await _drive_mrtr(client, "tool", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "A,B" From d89110ddd8283b9a8ccddd15f58f612fd1ef601c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 17:21:47 +0200 Subject: [PATCH 5/5] Worker-stable wire keys; restore typed model for unannotated eliciting resolvers Review follow-ups on #2986: - _state_key carries no id(...): it is module:qualname (a callable object uses its type's), so request_state round-trips and resumes on any worker (stateless HTTP). Two resolvers sharing that base (method instances, factory closures) are already disambiguated deterministically at registration (#N), so dropping the id is safe. - An eliciting resolver whose annotation lacks an Elicit[T] arm has elicit_schema None; its answer restored from request_state is now re-validated against the live Elicit.schema (via _elicit consulting res.state) instead of injecting a raw dict. - Move the datetime import to module top (AGENTS.md). Add regression tests: an unannotated eliciting resolver in a multi-round flow, and worker-stable wire keys for method instances and callable objects. --- src/mcp/server/mcpserver/resolve.py | 30 +++++++++---- tests/server/mcpserver/test_resolve.py | 58 +++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 11 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index b7649a0e5..870a21024 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -345,16 +345,17 @@ def __init__( def _state_key(fn: Callable[..., Any]) -> str: - """Process-stable wire key for a resolver's elicitation. + """Worker-stable base wire key for a resolver, derived only from registration data. - `id(fn)` isn't stable across `input_required` rounds, so key `input_requests` / - `request_state` by `module:qualname`. Bound methods add their `__self__` id so - two instances of the same method get distinct questions and stored outcomes - (the registered `Resolve(...)` holds the instance for the call's lifetime). + `input_requests`/`request_state` must round-trip through the client and resume on + any worker (stateless HTTP), so the key carries no `id(...)`: it is the resolver's + `module:qualname` (a callable object uses its type's). Distinct resolvers that + share this base - two instances of one method, two closures from one factory - are + disambiguated deterministically by `build_resolver_plans` (`base`, `base#1`, ...). """ - base = f"{getattr(fn, '__module__', '')}:{getattr(fn, '__qualname__', fn)!s}" - bound_self = getattr(fn, "__self__", None) - return f"{base}#{id(bound_self)}" if bound_self is not None else base + qualname = getattr(fn, "__qualname__", None) or type(fn).__qualname__ + module = getattr(fn, "__module__", None) or type(fn).__module__ + return f"{module}:{qualname}" async def resolve_arguments( @@ -411,7 +412,11 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul if wire_key in res.pending: # Already asked this round by another consumer; don't run the resolver again. raise _Pending - if wire_key in res.state: + # Restore a prior round's outcome directly only when its model is known from the + # `Elicit[T]` return arm. Without that (a resolver that elicits but isn't annotated + # `-> ... Elicit[T]`), fall through and re-run the resolver so `_elicit` can + # re-validate the stored answer against the live `Elicit.schema`. + if wire_key in res.state and (plan.elicit_schema is not None or res.state[wire_key].action != "accept"): outcome = _outcome_from_state(res.state[wire_key], plan.elicit_schema) res.cache[cache_key] = outcome # Carry the restored answer forward: if a later resolver is still pending, @@ -464,6 +469,13 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio if not res.input_required: return await res.context.elicit(elicit.message, elicit.schema) + # Answered in a prior round (restored without a known schema, e.g. an unannotated + # resolver): re-validate the stored entry against the live `Elicit.schema`. + if key in res.state and key not in res.answers: + outcome = _outcome_from_state(res.state[key], elicit.schema) + res.elicited[key] = outcome + return outcome + answer = res.answers.get(key) if answer is None: res.pending[key] = _elicit_request(elicit) diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 19d62cb9f..c7ab33d0f 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -1,6 +1,7 @@ """Tests for resolver dependency injection (MRTR) on MCPServer tools.""" from collections.abc import Callable +from datetime import datetime from typing import Annotated, Any, Literal import pytest @@ -879,8 +880,6 @@ async def both( @pytest.mark.anyio async def test_non_serializable_sibling_resolver_does_not_break_rounds(): - from datetime import datetime - mcp = MCPServer(name="NonSerializable") async def clock(ctx: Context) -> datetime: @@ -1078,3 +1077,58 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: result = await _drive_mrtr(client, "tool", {}, answer) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "A,B" + + +@pytest.mark.anyio +async def test_eliciting_resolver_without_elicit_arm_restores_a_typed_model(): + # A resolver annotated `-> Login` that actually returns `Elicit(...)` has no + # `Elicit[T]` return arm, so `elicit_schema` is None. Its answer, restored from + # request_state in a 3+ round flow, must still come back as a Login model (not a + # raw dict) so a dependent resolver/tool can use its attributes. + mcp = MCPServer(name="LyingAnnotation") + + # Annotated without an `Elicit[T]` return arm, so `elicit_schema` is None. + async def login(ctx: Context) -> object: + return Elicit("user?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"as {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "user" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "as octocat?" in params.message # login restored as a real model + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "act", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + +def test_wire_key_is_worker_stable_for_methods_and_callable_objects(): + from mcp.server.mcpserver.resolve import _state_key + + class Service: + async def token(self, ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + class CallableResolver: + async def __call__(self, ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + a, b = Service(), Service() + # No id(...) in the key: two instances of one method get the same base (they are + # disambiguated at registration, not here), and the key carries no memory address. + assert _state_key(a.token) == _state_key(b.token) + assert "#" not in _state_key(a.token) + assert _state_key(a.token).endswith("Service.token") + # Callable objects key by their type's qualname (they have no `__qualname__`). + assert _state_key(CallableResolver()).endswith("CallableResolver")