diff --git a/pyproject.toml b/pyproject.toml index d88869da1c..b98e64a487 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,6 +193,9 @@ strict-no-cover = { git = "https://github.com/pydantic/strict-no-cover" } [tool.pytest.ini_options] log_cli = true xfail_strict = true +markers = [ + "requirement(id): links a test to the entry in tests/interaction/_requirements.py it exercises", +] addopts = """ --color=yes --capture=fd diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360fa..b33fea4052 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -305,4 +305,4 @@ async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta async def send_roots_list_changed(self) -> None: """Send a notification that the roots list has changed.""" # TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support. - await self.session.send_roots_list_changed() # pragma: no cover + await self.session.send_roots_list_changed() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a77..86113874be 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -74,7 +74,7 @@ async def _default_elicitation_callback( context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: - return types.ErrorData( # pragma: no cover + return types.ErrorData( code=types.INVALID_REQUEST, message="Elicitation not supported", ) @@ -337,9 +337,7 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - from jsonschema import SchemaError, ValidationError, validate if result.structured_content is None: - raise RuntimeError( - f"Tool {name} has an output schema but did not return structured content" - ) # pragma: no cover + raise RuntimeError(f"Tool {name} has an output schema but did not return structured content") try: validate(result.structured_content, output_schema) except ValidationError as e: 
@@ -408,7 +406,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None return result - async def send_roots_list_changed(self) -> None: # pragma: no cover + async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.RootsListChangedNotification()) @@ -449,7 +447,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): # pragma: no cover + case types.PingRequest(): with responder: return await responder.respond(types.EmptyResult()) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c6338..aa3e50e07e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -210,7 +210,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: # Stream ended normally (server closed) - reset attempt counter attempt = 0 - except Exception: # pragma: lax no cover + except Exception: logger.debug("GET stream error", exc_info=True) attempt += 1 @@ -267,8 +267,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.debug("Received 202 Accepted") return - if response.status_code == 404: # pragma: no branch - if isinstance(message, JSONRPCRequest): # pragma: no branch + if response.status_code == 404: + if isinstance(message, JSONRPCRequest): error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) await ctx.read_stream_writer.send(session_message) @@ -492,17 +492,17 @@ async def handle_request_async(): async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" - if not self.session_id: # pragma: lax no cover - return + if not 
self.session_id: + return # pragma: no cover try: headers = self._prepare_headers() response = await client.delete(self.url, headers=headers) - if response.status_code == 405: # pragma: lax no cover + if response.status_code == 405: logger.debug("Server does not allow session termination") - elif response.status_code not in (200, 204): # pragma: lax no cover - logger.warning(f"Session termination failed: {response.status_code}") + elif response.status_code not in (200, 204): + logger.warning(f"Session termination failed: {response.status_code}") # pragma: no cover except Exception as exc: # pragma: no cover logger.warning(f"Session termination failed: {exc}") diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace45..d1a15120af 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -349,12 +349,12 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - if self._session_manager is None: # pragma: no cover - raise RuntimeError( + if self._session_manager is None: + raise RuntimeError( # pragma: no cover "Session manager can only be accessed after calling streamable_http_app(). " "The session manager is created lazily to avoid unnecessary initialization." 
) - return self._session_manager # pragma: no cover + return self._session_manager async def run( self, @@ -513,7 +513,7 @@ async def _handle_request( if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err)) - else: # pragma: no cover + else: response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") if isinstance(response, types.ErrorData) and span is not None: diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index e87388eee9..1441649808 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -94,7 +94,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes """ progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None - if progress_token is None: # pragma: no cover + if progress_token is None: return await self.request_context.session.send_progress_notification( @@ -237,7 +237,7 @@ async def close_sse_stream(self) -> None: This is a no-op if not using StreamableHTTP transport with event_store. The callback is only available when event_store is configured. 
""" - if self._request_context and self._request_context.close_sse_stream: # pragma: no cover + if self._request_context and self._request_context.close_sse_stream: # pragma: no branch await self._request_context.close_sse_stream() async def close_standalone_sse_stream(self) -> None: diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index e5b2af7d82..2f778eb514 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -185,5 +185,5 @@ async def render( raise ValueError(f"Could not convert prompt result to message: {msg}") return messages - except Exception as e: # pragma: no cover + except Exception as e: raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index b3471163b7..ec2365810e 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -244,7 +244,7 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - return self._lowlevel_server.session_manager # pragma: no cover + return self._lowlevel_server.session_manager @overload def run(self, transport: Literal["stdio"] = ...) -> None: ... 
diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 20b640527a..fc2f97a9cb 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -223,7 +223,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( @@ -447,7 +447,7 @@ async def elicit_url( metadata=ServerMessageMetadata(related_request_id=related_request_id), ) - async def send_ping(self) -> types.EmptyResult: # pragma: no cover + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( types.PingRequest(), @@ -479,11 +479,11 @@ async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification(types.ResourceListChangedNotification()) - async def send_tool_list_changed(self) -> None: # pragma: no cover + async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" await self.send_notification(types.ToolListChangedNotification()) - async def send_prompt_list_changed(self) -> None: # pragma: no cover + async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification(types.PromptListChangedNotification()) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 48192ff612..3e5261896b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -116,15 +116,15 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover - if scope["type"] != "http": + async def connect_sse(self, scope: Scope, 
receive: Receive, send: Send): + if scope["type"] != "http": # pragma: no cover logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") # Validate request headers for DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) - if error_response: + if error_response: # pragma: no cover await error_response(scope, receive, send) raise ValueError("Request validation failed") @@ -190,13 +190,13 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) - if error_response: + if error_response: # pragma: no cover return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") @@ -225,7 +225,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) try: message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") - except ValidationError as err: + except ValidationError as err: # pragma: no cover logger.exception("Failed to parse message") response = Response("Could not parse message", status_code=400) await response(scope, receive, send) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f14201857c..c85eeeeadf 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -179,7 +179,7 @@ def is_terminated(self) -> bool: """Check if this 
transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -198,11 +198,11 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover the disconnect. """ writer = self._sse_stream_writers.pop(request_id, None) - if writer: + if writer: # pragma: no branch writer.close() # Also close and remove request streams - if request_id in self._request_streams: + if request_id in self._request_streams: # pragma: no branch send_stream, receive_stream = self._request_streams.pop(request_id) send_stream.close() receive_stream.close() @@ -242,7 +242,7 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: self.close_sse_stream(request_id) async def close_standalone_stream_callback() -> None: # pragma: no cover @@ -293,7 +293,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: response_headers.update(headers) if self.mcp_session_id: @@ -320,10 +320,10 @@ def _create_json_response( ) -> Response: """Create a JSON response from a JSONRPCMessage.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: lax no cover - response_headers.update(headers) + if headers: + response_headers.update(headers) # pragma: no cover - if self.mcp_session_id: # pragma: lax no cover + if self.mcp_session_id: response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( @@ 
-344,7 +344,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: event_data["id"] = event_message.event_id return event_data @@ -374,7 +374,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -389,7 +389,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -421,7 +421,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: lax no cover + if not has_json: # pragma: no cover response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, @@ -469,7 +469,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -495,7 +495,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await 
self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -579,7 +579,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Store writer reference so close_sse_stream() can close it self._sse_stream_writers[request_id] = sse_stream_writer - async def sse_writer(): # pragma: lax no cover + async def sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: @@ -595,10 +595,10 @@ async def sse_writer(): # pragma: lax no cover # If response, remove from pending streams and close if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): break - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in SSE writer") finally: logger.debug("Closing SSE writer") @@ -628,14 +628,14 @@ async def sse_writer(): # pragma: lax no cover # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("SSE response error") await sse_stream_writer.aclose() await self._clean_up_memory_streams(request_id) finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -643,9 +643,9 @@ async def sse_writer(): # pragma: lax no cover INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: + if writer: # pragma: 
no cover await writer.send(Exception(err)) - return + return # pragma: no cover async def _handle_get_request(self, request: Request, send: Send) -> None: """Handle GET request to establish SSE. @@ -661,7 +661,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -673,7 +673,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) return @@ -683,11 +683,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -707,7 +707,7 @@ async def standalone_sse_writer(): async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for event_message in standalone_stream_reader: # pragma: lax no cover + async for event_message in standalone_stream_reader: # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) @@ -716,8 +716,8 @@ async def standalone_sse_writer(): # Send the message 
via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover - logger.exception("Error in standalone SSE writer") + except Exception: + logger.exception("Error in standalone SSE writer") # pragma: no cover finally: logger.debug("Closing standalone SSE writer") await self._clean_up_memory_streams(GET_STREAM_KEY) @@ -775,7 +775,7 @@ async def terminate(self) -> None: request_stream_keys = list(self._request_streams.keys()) # Close all request streams asynchronously - for key in request_stream_keys: # pragma: lax no cover + for key in request_stream_keys: await self._clean_up_memory_streams(key) # Clear the request streams dictionary immediately @@ -793,13 +793,13 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id response = self._create_error_response( @@ -809,7 +809,7 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) - async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover + async def _validate_request_headers(self, request: Request, send: Send) -> bool: if not await self._validate_session(request, send): return False if not await self._validate_protocol_version(request, send): @@ -818,7 +818,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool: async def _validate_session(self, 
request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: # pragma: no cover + if not self.mcp_session_id: # If we're not using session IDs, return True return True @@ -826,7 +826,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -851,11 +851,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. " @@ -867,14 +867,14 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. 
""" event_store = self._event_store if not event_store: - return + return # pragma: no cover try: headers = { @@ -883,7 +883,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Get protocol version from header (already validated in _validate_protocol_version) @@ -921,10 +921,10 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("Replay SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay sender") # Create and start EventSourceResponse @@ -936,13 +936,13 @@ async def send_event(event_message: EventMessage) -> None: try: await response(request.scope, request.receive, send) - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay response") finally: await sse_stream_writer.aclose() await sse_stream_reader.aclose() - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error replaying events") response = self._create_error_response( "Error replaying events", @@ -993,7 +993,7 @@ async def message_router(): if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( session_message.metadata is not None and isinstance( session_message.metadata, @@ -1009,7 +1009,7 @@ async def message_router(): # regardless of whether a client is connected # messages will be replayed on the re-connect event_id = None - if self._event_store: # pragma: 
lax no cover + if self._event_store: event_id = await self._event_store.store_event(request_stream_id, message) logger.debug(f"Stored {event_id} from {request_stream_id}") @@ -1020,14 +1020,14 @@ async def message_router(): except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client might reconnect and replay.""" ) except anyio.ClosedResourceError: - if self._terminated: + if self._terminated: # pragma: lax no cover logger.debug("Read stream closed by client") else: logger.exception("Unexpected closure of read stream in message router") @@ -1041,8 +1041,8 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream_id in list(self._request_streams.keys()): # pragma: lax no cover - await self._clean_up_memory_streams(stream_id) + for stream_id in list(self._request_streams.keys()): + await self._clean_up_memory_streams(stream_id) # pragma: no cover self._request_streams.clear() # Clean up the read and write streams diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab6..39d434505c 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -173,7 +173,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") # Assert task group is not None for type checking diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0e..707d4b61dd 100644 --- 
a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,19 +40,19 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not host: + if not host: # pragma: no cover logger.warning("Missing Host header in request") return False # Check exact match first - if host in self.settings.allowed_hosts: + if host in self.settings.allowed_hosts: # pragma: no cover return True # Check wildcard port patterns for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): + if allowed.endswith(":*"): # pragma: no branch # Extract base host from pattern base_host = allowed[:-2] # Check if the actual host starts with base host and has a port @@ -62,19 +62,19 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests - if not origin: + if not origin: # pragma: no cover return True # Check exact match first - if origin in self.settings.allowed_origins: + if origin in self.settings.allowed_origins: # pragma: no cover return True # Check wildcard port patterns for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): + if allowed.endswith(":*"): # pragma: no branch # Extract base origin from pattern base_origin = allowed[:-2] # Check if the actual origin starts with base origin and has a port @@ -103,14 +103,14 @@ async def 
validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=421) - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=403) - return None # pragma: no cover + return None diff --git a/tests/interaction/README.md b/tests/interaction/README.md new file mode 100644 index 0000000000..ba08fa564e --- /dev/null +++ b/tests/interaction/README.md @@ -0,0 +1,199 @@ +# Interaction-model test suite + +This suite enumerates the MCP interaction model as end-to-end tests: one test per piece of +functionality, asserting the full client↔server round trip through the public API. It exists to +pin the SDK's observable behaviour — every request type, every notification direction, every +error plane — so that internal rewrites of the send/receive path can be proven equivalent by +running the suite before and after. + +```bash +uv run --frozen pytest tests/interaction/ +``` + +The whole suite is in-memory and event-driven; it runs in about a second. + +## Ground rules + +- **Public API only.** Tests drive a `Client` connected to a `Server` or `MCPServer`. Nothing + reaches into session internals, so the suite keeps working when those internals change. 
+ `ClientSession` is used directly only for behaviours `Client` cannot express (skipping + initialization, requesting a non-default protocol version). +- **Pin current behaviour.** Every test passes against the current `main`, including behaviours + that diverge from the specification. A failing or xfailed test proves nothing about whether a + rewrite preserved behaviour; a passing test that exactly pins the wrong output does. Known + divergences are recorded as data on the requirement (see below), not worked around in the test. +- **Spec-mandated assertions, not implementation quirks.** Error *codes* are asserted against + the constants in `mcp.types`; error *message strings* are pinned only where they are the + SDK's own deliberate output. +- **No sleeps, no real I/O.** Concurrency is coordinated with `anyio.Event`; every wait that + could hang is bounded by `anyio.fail_after(5)`. The streamable HTTP tests drive the Starlette + app in-process through the suite's streaming ASGI bridge (`transports/_bridge.py`), which + delivers each response chunk as the server produces it — full duplex, but still no sockets, + threads, or subprocesses anywhere. 
+ +## Layout + +```text +tests/interaction/ + _requirements.py the requirements manifest (see below) + _helpers.py shared type aliases + the wire-recording transport + _connect.py the transport-parametrized connection factories + conftest.py the connect fixture (the transport matrix) + test_coverage.py enforces the manifest ↔ test contract + lowlevel/ one file per feature area, against the low-level Server + mcpserver/ the same feature areas in MCPServer's natural idiom + transports/ behaviour specific to one transport (modes, streams, framing) +``` + +The two server APIs produce genuinely different wire output for the same conceptual feature +(`MCPServer` generates schemas, converts exceptions to `isError` results, attaches structured +content), so they get parallel directories with mirrored file names rather than one parametrized +test body — each directory pins its flavour's true output exactly. + +### The transport matrix + +Transport-agnostic tests take the `connect` fixture instead of constructing `Client(server)` +directly, and therefore run once per transport: over the in-memory transport and over the +server's real streamable HTTP app driven in process through the streaming bridge. A test connects +the same way in either case — `async with connect(server, ...) as client:` — and asserts the same +output, because the transport is not supposed to change observable behaviour. Tests that are tied +to one transport do not use the fixture: the wire-recording tests (their seam is the in-memory +stream pair), the bare-`ClientSession` lifecycle tests, the real-clock timeout tests (the timeout +machinery is transport-independent and must not race transport latency), and everything under +`transports/`, which pins behaviour only observable on that transport. 
+ +A transport conformance test in `transports/` speaks raw `httpx` against the mounted ASGI app +**only** when its assertion is about HTTP semantics that `Client` cannot observe — status codes, +response headers, SSE event fields, which stream a message travels on. Any other behaviour is +asserted through a `Client`, connected to the mounted app via `client_via_http(http)` so several +clients can share one session manager. + +## The requirements manifest + +`_requirements.py` maps every behaviour the suite covers to the reason it must hold: + +```python +"tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content.", +), +``` + +- **`source`** is a deep link into the MCP specification for externally mandated behaviour, + the literal string `"sdk"` for behaviour the SDK chose where the spec is silent, or + `"issue:#n"` for a regression lock. +- **`behavior`** describes the *required* behaviour — what the specification (or the SDK's own + contract) says should happen. Tests always pin the SDK's current behaviour; where that falls + short of `behavior`, the gap is recorded as data rather than hidden in the test. +- **`divergence`** records that gap for entries whose tests pin the divergent current behaviour. +- **`deferred`** marks a behaviour that is tracked but not yet covered by a test in this suite. + The reason names the covering tests elsewhere in the repo, starts with "Not implemented in the + SDK" for genuine feature gaps, or starts with "Not yet covered here" for tests that are planned. +- **`transports`** names the transports a behaviour applies to; omitted means transport-independent. +- **`issue`** carries the tracking link for a recorded gap once one is filed. + +Tests link themselves to the manifest with a decorator: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: ... 
+``` + +`test_coverage.py` enforces the contract in both directions: every non-deferred requirement must +be exercised by at least one test, every deferred requirement by none, and an unknown ID fails at +import time. A behaviour without a manifest entry cannot be silently half-tested, and a manifest +entry without a test cannot be silently aspirational. + +### The divergence lifecycle + +1. A test reveals that the SDK does not do what the spec says. The test pins what the SDK + *actually does* and a `Divergence(note=..., issue=...)` goes on the requirement. +2. When the behaviour is eventually fixed, the pinned test fails. Whoever makes the change finds + the divergence note explaining that the old behaviour was a known gap, re-pins the test to the + spec-correct output, and deletes the `Divergence`. +3. An empty divergence list means the SDK is spec-conformant on every behaviour the suite covers. + +A requirement may carry both `divergence` and `deferred`: the divergence records that the SDK falls +short of the spec, and the deferral records why no test pins it (typically because the divergent +behaviour cannot be driven through the public API). Divergence alone implies a test pins the +divergent behaviour; divergence plus deferred means the gap is known but unpinned. + +This is also the triage key for any rewrite: a test that fails on the new code path either has a +divergence note (the rewrite accidentally fixed a known gap — decide whether to keep the fix) or +it does not (the rewrite broke something that was correct — fix the rewrite). + +### When a new spec revision is released + +1. Update `SPEC_REVISION` and walk the new revision's changelog. +2. For each changed interaction, find its requirements (the IDs use the wire method strings the + changelog speaks in), re-audit the tests against the new text, and update `source` links and + assertions where behaviour legitimately changed. +3. 
New interactions get new requirements and new tests; removed interactions get their + requirements deleted along with their tests. +4. A behaviour that is correct under both revisions needs no change beyond the `source` link. + +## Writing a test + +The shortest complete example of the conventions: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) +``` + +- **The server is defined inside the test** (or in a small fixture at the top of the file when + several tests genuinely share it). The whole observable behaviour fits on one screen. +- **Test names are behaviour sentences** — they state the observable outcome, not the feature + being poked. Docstrings add the one or two sentences of context a reviewer needs, including + whether the assertion is spec-mandated, SDK-defined, or a known divergence. +- **Handlers assert their dispatch identity first** (`assert params.name == "add"`), proving the + request that arrived is the request the test sent. +- **The result proves the round trip.** Server-side observations travel back to the test through + the protocol itself (a tool returns what it saw) or through a closure-captured list; the test + asserts after the call returns. +- **Order within a test**: server handlers → server construction → client callbacks → connect → + act → assert. 
The test reads in the order the conversation happens. +- A registered handler or tool that a test never invokes gets a `raise NotImplementedError` body + so it cannot silently become load-bearing. + +### Choosing an assertion + +| The property under test is… | Assert with | +|---|---| +| the result of a transformation (arguments → output, exception → error result) | `result == snapshot(...)` of the full object, so any field the implementation adds or drops fails the test | +| pass-through of an opaque value (`_meta`, cursors) | identity against the same variable that was sent — a snapshot of a pass-through value only matches the input because a human checked two literals correspond | +| an error | `pytest.raises(MCPError)` and a snapshot of `exc.value.error` when the message is the SDK's own; a plain `==` on `.code` against the `mcp.types` constant when it is not | +| third-party output embedded in a result (validation messages) | the stable prefix only — never pin text that changes with a dependency upgrade | + +### Notifications and concurrency + +The client's receive loop dispatches each incoming message to completion before reading the next, +and the in-memory transport delivers everything on one ordered stream. Together these guarantee +that every notification a server handler emits before its response reaches the client callback +before the originating request returns — so tests collect notifications into a plain list and +assert after the call, with no synchronisation. The exceptions: + +- a notification not triggered by a request the test is awaiting needs an `anyio.Event` set in + the receiving handler and awaited under `anyio.fail_after(5)`; +- the ordering guarantee does not survive transports that split messages across streams (the + streamable HTTP standalone GET stream) — see `transports/test_streamable_http.py`. 
+ +### Coverage + +CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the +build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a +pragma'd line in `src/`, delete the pragma in the same change. Do not add new `# pragma`, +`# type: ignore`, or `# noqa` comments; restructure instead. diff --git a/tests/interaction/__init__.py b/tests/interaction/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py new file mode 100644 index 0000000000..baca975917 --- /dev/null +++ b/tests/interaction/_connect.py @@ -0,0 +1,358 @@ +"""Transport-parametrized connection factories for the interaction suite. + +The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body +runs over each transport without naming any of them: the factory is a drop-in replacement for +constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the +server's real Starlette app through the in-process streaming bridge, so the full transport layer +(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. 
+""" + +import gc +import warnings +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Protocol + +import httpx +from httpx_sse import ServerSentEvent, aconnect_sse +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route + +from mcp.client.client import Client +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.mcpserver import MCPServer +from mcp.server.sse import SseServerTransport +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + jsonrpc_message_adapter, +) +from tests.interaction.transports._bridge import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + +# DNS-rebinding protection validates Host/Origin headers against a real network attack that cannot +# exist for an in-process ASGI app, so the in-process factories disable it; tests that exercise the +# protection itself pass explicit settings (or transport_security=None to get the localhost +# auto-enable behaviour). +NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False) + + +class Connect(Protocol): + """Connect a Client to a server over the transport selected by the `connect` fixture. 
+ + Accepts the same keyword arguments as `Client` and yields the connected client. + """ + + def __call__( + self, + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, + ) -> AbstractAsyncContextManager[Client]: ... + + +@asynccontextmanager +async def connect_in_memory( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server over the in-memory transport.""" + async with Client( + server, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +@asynccontextmanager +async def connect_over_streamable_http( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> 
AsyncIterator[Client]: + """Yield a Client connected to the server's streamable HTTP app, entirely in process. + + With the defaults this is the matrix leg (stateful sessions, SSE responses); the + transport-specific tests pass `stateless_http` or `json_response` to select the other + server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so + the client's reconnection wait is a no-op). + """ + app = server.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client: + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +@asynccontextmanager +async def mounted_app( + server: Server | MCPServer, + *, + stateless_http: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, + on_request: Callable[[httpx.Request], Awaitable[None]] | None = None, + headers: dict[str, str] | None = None, +) -> AsyncIterator[tuple[httpx.AsyncClient, StreamableHTTPSessionManager]]: + """Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client. + + Yields the httpx client (rooted at the in-process origin) and the live session manager. 
Tests + use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test + speaks HTTP through the yielded client directly; for client-driven assertions the test wraps + that client in `client_via_http(http)`, which lets several `Client`s share the one mounted + session manager. `on_request` records every outgoing HTTP request before it leaves the + yielded client. + + DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the + localhost auto-enable behaviour) to test the protection itself. + """ + app = server.streamable_http_app( + stateless_http=stateless_http, + event_store=event_store, + retry_interval=retry_interval, + transport_security=transport_security, + ) + event_hooks = {"request": [on_request]} if on_request is not None else None + async with server.session_manager.run(): + async with httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers + ) as http_client: + yield http_client, server.session_manager + + +@asynccontextmanager +async def client_via_http( + http_client: httpx.AsyncClient, + *, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Connect a `Client` over an already-mounted streamable HTTP app. + + Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a + client-driven assertion can sit alongside raw-httpx assertions in the same test. The + underlying `httpx.AsyncClient` is left open when the `Client` exits. 
+ """ + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + async with Client( + transport, + logging_callback=logging_callback, + message_handler=message_handler, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: + """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data.""" + return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] + + +async def post_jsonrpc( + http: httpx.AsyncClient, body: dict[str, object], *, session_id: str | None = None +) -> tuple[httpx.Response, list[JSONRPCMessage]]: + """POST a JSON-RPC body and read its SSE response stream to completion. + + Returns the HTTP response (for header/status assertions) and the parsed JSON-RPC messages + that arrived on the response's SSE stream. Only meaningful for requests the server answers + with `text/event-stream`; for error responses or 202 notification acknowledgements, use + `httpx.AsyncClient.post` directly and assert on the response. + """ + async with aconnect_sse(http, "POST", "/mcp", json=body, headers=base_headers(session_id=session_id)) as source: + events = [event async for event in source.aiter_sse()] + return source.response, parse_sse_messages(events) + + +def base_headers(*, session_id: str | None = None) -> dict[str, str]: + """Standard request headers for raw-httpx streamable-HTTP tests. + + Every well-formed request carries these (Accept covering both response representations, + Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session + ID once one exists), so a test that wants to assert a specific rejection only varies the one + header under test. 
+ """ + headers = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + "mcp-protocol-version": LATEST_PROTOCOL_VERSION, + } + if session_id is not None: + headers["mcp-session-id"] = session_id + return headers + + +def initialize_body(request_id: int = 1) -> dict[str, object]: + """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" + params = InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="raw", version="0.0.0"), + ) + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True) + ).model_dump(by_alias=True, exclude_none=True) + + +async def initialize_via_http(http: httpx.AsyncClient) -> str: + """Perform the initialize handshake over a raw `httpx.AsyncClient` and return the session ID. + + Validates the SSE response and sends the `notifications/initialized` follow-up, so the server + is fully ready for subsequent feature requests when this returns. + """ + async with aconnect_sse(http, "POST", "/mcp", json=initialize_body(), headers=base_headers()) as source: + assert source.response.status_code == 200 + # An event-store-backed server opens the stream with a priming event (empty data); skip it. + events = [event async for event in source.aiter_sse() if event.data] + assert len(events) == 1 + assert JSONRPCResponse.model_validate_json(events[0].data).id == 1 + session_id = source.response.headers["mcp-session-id"] + initialized = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, + headers=base_headers(session_id=session_id), + ) + assert initialized.status_code == 202 + return session_id + + +def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: + """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. 
+ + `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which + the SSE-specific tests need; building the app explicitly here gives both server flavours the + same routing while keeping that handle. + """ + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): + await lowlevel.run(read, write, lowlevel.create_initialization_options()) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + return app, sse + + +@asynccontextmanager +async def connect_over_sse( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" + app, _ = build_sse_app(server) + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE server transport's connect_sse runs the entire MCP session inside the GET + # request and only releases its streams after that request observes a disconnect, so the + # bridge must let the application drain rather than cancelling at close. 
+ return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) + try: + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + finally: + # SseServerTransport.connect_sse hands its internal SSE-chunk receive stream to + # sse_starlette's EventSourceResponse, which never closes it when its task group is + # cancelled on disconnect (see notes/findings.md). Collect the orphan here so its + # ResourceWarning fires deterministically inside this fixture instead of at an + # arbitrary later GC. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ResourceWarning) + gc.collect() diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py new file mode 100644 index 0000000000..25833b0ca5 --- /dev/null +++ b/tests/interaction/_helpers.py @@ -0,0 +1,107 @@ +"""Shared helpers for the interaction suite. + +Keep this module small: it exists only for (a) types that every test would otherwise have to +assemble from the SDK's internals to annotate a client callback, and (b) the recording transport +used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses +them. 
+""" + +from types import TracebackType + +import anyio +from typing_extensions import Self + +from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ClientResult, ServerNotification, ServerRequest + +# TODO: this union is the parameter type of every client message handler (MessageHandlerFnT), +# but the SDK does not export a name for it -- writing a correctly-typed handler requires +# importing RequestResponder from mcp.shared.session and assembling the union by hand. It +# should be a named, exported alias next to MessageHandlerFnT (like ClientRequestContext is +# for the request callbacks), at which point this alias can be deleted. +IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception +"""Everything a client message handler can receive.""" + + +class _RecordingReadStream: + """Delegates to a read stream, appending every received message to a log.""" + + def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None: + self._inner = inner + self._log = log + + async def receive(self) -> SessionMessage | Exception: + item = await self._inner.receive() + self._log.append(item) + return item + + async def aclose(self) -> None: + await self._inner.aclose() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> SessionMessage | Exception: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration from None + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class _RecordingWriteStream: + """Delegates to a write stream, appending every sent message to a log.""" + + def 
__init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None: + self._inner = inner + self._log = log + + async def send(self, item: SessionMessage, /) -> None: + self._log.append(item) + await self._inner.send(item) + + async def aclose(self) -> None: + await self._inner.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class RecordingTransport: + """Wraps a Transport and records every message crossing the client's transport boundary. + + `sent` holds everything the client wrote towards the server; `received` holds everything the + server delivered to the client. The recording sits at the transport seam -- the exact payloads + a real transport would serialise -- and never touches the session, so wire-level assertions + written against it survive changes to the receive path. + """ + + def __init__(self, inner: Transport) -> None: + self.inner = inner + self.sent: list[SessionMessage] = [] + self.received: list[SessionMessage | Exception] = [] + + async def __aenter__(self) -> TransportStreams: + read_stream, write_stream = await self.inner.__aenter__() + return _RecordingReadStream(read_stream, self.received), _RecordingWriteStream(write_stream, self.sent) + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + return await self.inner.__aexit__(exc_type, exc_val, exc_tb) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py new file mode 100644 index 0000000000..b5897ee46d --- /dev/null +++ b/tests/interaction/_requirements.py @@ -0,0 +1,2668 @@ +"""Requirements manifest for the interaction-model test suite. + +Every user-facing behaviour the SDK must satisfy, keyed by a stable `:[:]` +ID. 
Each entry owns the tests that exercise it: tests declare `@requirement("<id>")` (a test that +proves several behaviours stacks several decorators) and `test_coverage.py` enforces the contract +in both directions: every non-deferred requirement has at least one test, and every test carries +at least one requirement. + +Sources: + spec URL -- externally mandated by the MCP specification (deep link to the section) + `sdk` -- a behavioural guarantee the SDK chose; not spec-mandated + `issue:#n` -- regression lock-in for a previously fixed bug + +The `behavior` sentence describes the REQUIRED behaviour -- what the specification (or the SDK's +own contract) says should happen. Tests always pin the SDK's current behaviour. Where current +behaviour falls short of `behavior`, the gap is recorded as data: `divergence` on entries whose +tests pin the divergent behaviour, or `deferred` on entries that are tracked but not yet covered +by a test in this suite. An entry may carry both: `divergence` records the spec-compliance gap +(issue-able) and `deferred` records why no test exists; `divergence` alone implies a test pins +the divergent behaviour. `issue` carries the tracking link for a recorded gap once one is filed. + +`deferred` reasons take one of three shapes: where the behaviour is exercised elsewhere in this +repo the reason names the covering test path; where the SDK does not implement the behaviour at +all the reason starts with "Not implemented in the SDK"; and where an interaction-level test is +planned but not yet written the reason starts with "Not yet covered here". + +`transports` records which transports a behaviour applies to (or is observable on); None means +the behaviour is transport-independent. 
+ +The ID vocabulary and entry granularity are aligned with the TypeScript SDK's end-to-end +requirements suite, so coverage and recorded divergences can be compared across the two SDKs +entry by entry; IDs that exist in only one SDK reflect genuinely different API surface. +""" + +import re +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal, TypeVar + +import pytest + +SPEC_REVISION = "2025-11-25" +SPEC_BASE_URL = f"https://modelcontextprotocol.io/specification/{SPEC_REVISION}" + +Transport = Literal["in-memory", "stdio", "streamable-http", "sse"] + +_TestFn = TypeVar("_TestFn", bound=Callable[..., object]) + +_SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") + +_TASKS_DEFERRAL = ( + "Tasks are experimental and the spec is being substantially revised; python task behaviour is " + "covered by tests/experimental/tasks/ until the next spec revision settles." +) + + +@dataclass(frozen=True, kw_only=True) +class Divergence: + """A documented gap between the SDK behaviour this suite pins and what `source` mandates.""" + + note: str + issue: str | None = None + + +@dataclass(frozen=True, kw_only=True) +class Requirement: + """A single testable behaviour and the provenance of why it must hold.""" + + source: str + behavior: str + transports: tuple[Transport, ...] 
| None = None + divergence: Divergence | None = None + deferred: str | None = None + issue: str | None = None + + def __post_init__(self) -> None: + if not _SOURCE_PATTERN.fullmatch(self.source): + raise ValueError(f"source must be a specification URL, 'sdk', or 'issue:#n', got {self.source!r}") + + +REQUIREMENTS: dict[str, Requirement] = { + # ═══════════════════════════════════════════════════════════════════════════ + # Lifecycle & version negotiation + # ═══════════════════════════════════════════════════════════════════════════ + "lifecycle:capability:client-not-declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects sending notifications or registering handlers for capabilities it did not declare." + ), + divergence=Divergence( + note=( + "The client does not check its own declared capabilities before sending notifications or " + "serving callbacks; nothing prevents a caller from violating the spec's SHOULD." + ), + ), + deferred=( + "Not implemented in the SDK: the client does not check its own declared capabilities before " + "sending notifications or serving callbacks." + ), + ), + "lifecycle:capability:server-not-advertised": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects calls to methods (e.g. resources/list) for capabilities the server did not advertise." + ), + divergence=Divergence( + note=( + "The client sends any request regardless of the server's advertised capabilities and " + "surfaces whatever the server answers; the spec's SHOULD is not enforced." + ), + ), + deferred=( + "Not implemented in the SDK: the client sends any request regardless of the server's " + "advertised capabilities and surfaces whatever the server answers." 
+ ), + ), + "lifecycle:initialize:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Connecting sends initialize with the protocol version, client capabilities, and client " + "info; the server responds with its own and the connection is established." + ), + ), + "lifecycle:initialize:server-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The initialize result identifies the server: name and version, plus title when declared.", + ), + "lifecycle:initialize:instructions": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="A server may include an instructions string in the initialize result; the client exposes it.", + ), + "lifecycle:initialize:capabilities:from-handlers": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The server advertises a capability for each feature area it has a registered handler for, " + "and omits the capability for areas it does not." + ), + ), + "lifecycle:initialize:capabilities:minimal": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with no feature handlers advertises no feature capabilities.", + ), + "lifecycle:initialize:client-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The client's name, version, and title are visible to server handlers after initialization.", + ), + "lifecycle:initialize:client-capabilities": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The client capabilities visible to the server reflect which client callbacks are configured " + "(sampling, elicitation, roots)." 
+ ), + ), + "lifecycle:initialized-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "After successful initialization, the client sends exactly one initialized notification, " + "before any non-ping request." + ), + ), + "lifecycle:ping": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="ping in either direction returns an empty result.", + ), + "ping:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A client-initiated ping receives an empty result from the server.", + ), + "ping:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A server-initiated ping receives an empty result from the client.", + ), + "lifecycle:requests-before-initialized": Requirement( + source="sdk", + behavior=( + "A request other than ping sent before the initialization handshake completes is rejected with an error." + ), + ), + "lifecycle:pre-initialization-ordering": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Before initialization completes, the client sends no requests other than pings, and the " + "server sends no requests other than pings and logging." + ), + divergence=Divergence( + note=( + "The server's send methods (create_message / elicit_form / list_roots) do not check " + "initialization state before sending; on the client side, Client always completes the " + "handshake before any caller code runs." + ), + ), + deferred=( + "Not implemented in the SDK: neither side enforces sender-side restraint. 
The server's send " + "methods (create_message / elicit_form / list_roots) do not check initialization state before " + "sending, and there is no natural hook to issue a server-to-client request between the " + "initialize response and the initialized notification through the public API; on the client " + "side, Client always completes the handshake before any caller code runs." + ), + ), + "lifecycle:version:downgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server returns an older supported protocol version, the client downgrades to it " + "and the connection succeeds at that version." + ), + ), + "lifecycle:version:match": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server supports the requested protocol version it echoes that version in the " + "initialize result, and the connection proceeds at that version." + ), + ), + "lifecycle:version:server-fallback-latest": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "An initialize request carrying a protocol version the server does not support is answered " + "with another version the server supports — the latest one — rather than an error." + ), + ), + "lifecycle:version:reject-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "A client that receives an initialize response carrying a protocol version it does not " + "support fails initialization with an error rather than proceeding with the session." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Protocol primitives: cancellation, timeout, progress, errors, _meta + # ═══════════════════════════════════════════════════════════════════════════ + "protocol:request-id:unique": Requirement( + source=f"{SPEC_BASE_URL}/basic#requests", + behavior=( + "Every request sent on a session carries a unique, non-null string or integer id; ids are " + "never reused within the session." + ), + ), + "protocol:notifications:no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic#notifications", + behavior=( + "Notifications are never answered: every message the server delivers is either the response " + "to a request the client sent or a notification carrying no id." + ), + ), + "protocol:cancel:abort-signal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#cancellation-flow", + behavior=( + "Cancelling an in-flight request through the client API sends notifications/cancelled with " + "the request id and fails the local call." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side API to cancel an in-flight " + "request; cancellation requires hand-constructing the notification (which is how " + "protocol:cancel:in-flight exercises the receiving side)." + ), + ), + "protocol:cancel:handler-abort-propagates": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="On the receiving side, a cancellation notification stops the running request handler.", + ), + "protocol:cancel:in-flight": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A cancellation notification for an in-flight request stops the server-side handler, and the " + "receiver does not send a response for the cancelled request." 
+ ), + divergence=Divergence( + note=( + "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " + "request; the server sends an error response (code 0, 'Request cancelled'), which is what " + "unblocks the SDK client's pending call." + ), + ), + ), + "protocol:cancel:initialize-not-cancellable": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="The client never sends notifications/cancelled for the initialize request.", + deferred=( + "Not implemented in the SDK: the client has no public cancellation API at all, so no pathway " + "exists that could cancel initialize; there is no distinct behaviour to pin beyond that absence." + ), + ), + "protocol:cancel:late-response-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A response that arrives after the sender issued notifications/cancelled is ignored; the " + "request stays failed and no error is raised." + ), + divergence=Divergence( + note=( + "A response whose id matches no in-flight request is delivered to the message handler " + "as a RuntimeError rather than being silently ignored. The post-cancellation case is the " + "same code path; tested in its unknown-id form because that is deterministic without the " + "client-side cancellation API the SDK does not yet provide." + ), + ), + ), + "protocol:cancel:server-survives": Requirement( + source="sdk", + behavior="The session continues to serve new requests after an earlier request was cancelled.", + ), + "protocol:cancel:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " + "cancels it, and the client stops processing the cancelled request." 
+ ), + divergence=Divergence( + note=( + "Abandoning a server-side send_request emits no cancellation notification, and the client " + "could not act on one anyway: client callbacks run inline in the receive loop, so a " + "cancellation is not even read until the callback has finished." + ), + ), + deferred=( + "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " + "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " + "the client could not act on one anyway because client callbacks run inline in the receive " + "loop, so a cancellation would not even be read until the callback had already finished." + ), + ), + "protocol:cancel:unknown-id-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling", + behavior=( + "The receiver silently ignores a cancellation notification referencing an unknown or " + "already-completed request id; no error response is sent and no exception is raised." + ), + ), + "protocol:cancel:sender-targeting": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "Cancellation notifications reference only requests that were previously issued in the same " + "direction and are believed to still be in flight." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side cancel API to drive (see " + "protocol:cancel:abort-signal), so the sender-side targeting rule has nothing to pin." + ), + ), + "protocol:error:connection-closed": Requirement( + source="sdk", + behavior="Closing the transport fails all in-flight requests with a connection-closed error.", + ), + "protocol:error:internal-error": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior=( + "An unhandled exception in a request handler is returned to the caller as JSON-RPC error " + "-32603 Internal error." 
+ ), + divergence=Divergence( + note=( + "The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and " + "leaks str(exc) as the error message." + ), + ), + ), + "protocol:error:invalid-params": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request with malformed params is answered with JSON-RPC error -32602 Invalid params.", + ), + "protocol:error:method-not-found": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", + ), + "protocol:meta:related-task": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="Messages may carry related-task _meta associating them with a task.", + deferred=_TASKS_DEFERRAL, + ), + "meta:request-to-handler": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object the client attaches to a request is visible to the server handler.", + ), + "meta:result-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object a handler attaches to its result is delivered to the client.", + ), + "protocol:progress:callback": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a handler during a request are delivered to the caller's " + "progress callback, in order, with their progress, total, and message." + ), + ), + "protocol:progress:token-injected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Supplying a progress callback attaches a progress token to the outgoing request, which the " + "server-side handler can observe in its request metadata." 
+ ), + ), + "protocol:progress:token-unique": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=("Concurrent in-flight requests that each supply a progress callback carry distinct progress tokens."), + ), + "protocol:progress:monotonic": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "The progress value increases with each notification for a given token, even when the total is unknown." + ), + divergence=Divergence( + note=( + "The spec MUST is not enforced: progress values are not validated on either side, so a " + "handler that emits non-increasing values has them forwarded to the callback unchanged." + ), + ), + ), + "protocol:progress:stops-after-completion": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", + behavior="Progress notifications for a token stop once the associated request completes.", + divergence=Divergence( + note=( + "send_progress_notification does not check whether the token's request has already " + "completed; the late notification is sent and reaches the client." + ), + ), + ), + "protocol:progress:late-dropped-by-client": Requirement( + source="sdk", + behavior=( + "A progress notification that arrives after its request has completed is not delivered to the " + "original progress callback." + ), + ), + "protocol:progress:no-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Without a progress callback no token is attached, and a handler that reports progress anyway " + "sends nothing." 
+ ), + ), + "protocol:progress:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="A progress notification sent by the client is delivered to the server's progress handler.", + ), + "protocol:timeout:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "A request that exceeds its read timeout fails with a request-timeout error instead of " + "waiting forever for the response." + ), + ), + "protocol:timeout:max-total": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A maximum total timeout is enforced even when progress notifications keep arriving.", + divergence=Divergence( + note=( + "There is no maximum-total-timeout option; only the per-request read timeout exists, so the " + "spec's SHOULD that an overall maximum is always enforced cannot be satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: there is no maximum-total-timeout option; only the per-request " + "read timeout exists." + ), + ), + "protocol:timeout:reset-on-progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="When configured to do so, each progress notification resets the request's read timeout.", + deferred=( + "Not implemented in the SDK: progress notifications do not reset the request read timeout and " + "no option exists to enable that." + ), + ), + "protocol:timeout:sends-cancellation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "When a request times out, the sender issues notifications/cancelled for that request before " + "failing the local call." + ), + divergence=Divergence( + note=( + "The client only raises locally and sends nothing on timeout, so the server keeps running the handler." 
+ ), + ), + ), + "protocol:timeout:session-survives": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="The session continues to serve new requests after an earlier request timed out.", + ), + "protocol:timeout:session-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A session-level read timeout applies to every request that does not override it.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools + # ═══════════════════════════════════════════════════════════════════════════ + "tools:call:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#audio-content", + behavior="A tool result can carry audio content: base64 data with a mimeType.", + ), + "tools:call:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#embedded-resources", + behavior="A tool result can carry an embedded resource with full text or blob contents.", + ), + "tools:call:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#image-content", + behavior="A tool result can carry image content: base64 data with a mimeType.", + ), + "tools:call:content:mixed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-result", + behavior="A tool result can carry multiple content blocks of different types; order is preserved.", + ), + "tools:call:content:resource-link": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior="A tool result can carry a resource_link content block referencing a resource by URI.", + ), + "tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content to the caller.", + ), + "tools:call:concurrent": Requirement( + source="sdk", + behavior=( + "Multiple tool calls in flight on one session are dispatched concurrently, and each caller " + "receives the response to 
its own request." + ), + ), + "tools:call:elicitation-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#user-interaction-model", + behavior=( + "A tool handler that issues an elicitation receives the client's result and can embed it in " + "the tool call result." + ), + ), + "tools:call:is-error": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "A tool execution failure is returned as a result with isError true and the failure described " + "in content, not as a JSON-RPC error." + ), + ), + "tools:call:logging-mid-execution": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "Log notifications emitted by a tool handler during execution reach the client's logging " + "callback before the tool result returns." + ), + ), + "tools:call:progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a tool handler reach the caller's progress callback before " + "the tool result returns." + ), + ), + "tools:call:sampling-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A tool handler that issues a sampling request receives the client's completion and can embed " + "it in the tool call result." 
+ ), + ), + "tools:call:structured-content": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool result can carry structuredContent alongside content; the client receives both.", + ), + "tools:call:structured-content:text-mirror": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool returning structured content also returns the serialized JSON as a text content block.", + ), + "tools:call:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + ), + "tools:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#capabilities", + behavior="A server with a list_tools handler advertises the tools capability in its initialize result.", + ), + "tools:input-schema:json-schema-2020-12": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "A tool registered with a JSON Schema 2020-12 inputSchema (nested objects, $defs references) " + "is discoverable and callable." + ), + ), + "tools:input-schema:preserve-additional-properties": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema additionalProperties as registered.", + ), + "tools:input-schema:preserve-defs": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema $defs as registered.", + ), + "tools:input-schema:preserve-schema-dialect": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves the inputSchema $schema dialect URI as registered.", + ), + "tools:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", + behavior=( + "When the tool set changes, a server that declared the tools listChanged capability sends " + "notifications/tools/list_changed and it reaches the client's handler." 
+ ), + ), + "tools:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#listing-tools", + behavior="tools/list returns the registered tools with name, description, and inputSchema.", + ), + "tools:list:metadata": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " + "are delivered to the client unchanged." + ), + ), + "tools:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "tools/list supports cursor pagination: the nextCursor returned by a list handler round-trips " + "back to the handler as an opaque cursor until the listing is exhausted." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "client:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="The client skips structured-content validation when the tool result has isError true.", + ), + "client:output-schema:validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "A tool result whose structuredContent does not conform to the tool's declared outputSchema " + "is rejected by the client: the call raises instead of returning the invalid result." + ), + ), + "client:output-schema:missing-structured": Requirement( + source="sdk", + behavior="A tool that declares an output schema but returns no structuredContent fails client-side validation.", + ), + "client:output-schema:auto-list": Requirement( + source="sdk", + behavior=( + "Calling a tool whose output schema is not yet cached issues an implicit tools/list to " + "populate the cache; subsequent calls of the same tool do not." 
+ ), + divergence=Divergence( + note=( + "Design concern rather than spec violation: the implicit request is invisible to the " + "caller, and against a server that registers only on_call_tool a successful call surfaces " + "as METHOD_NOT_FOUND from a tools/list the caller never asked for." + ), + ), + ), + "mcpserver:output-schema:missing-structured": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior="A tool with an output schema whose function returns no structured content produces a server error.", + ), + "mcpserver:output-schema:server-validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "MCPServer validates structured content against the tool's output schema before returning; a " + "mismatch produces a server error." + ), + ), + "mcpserver:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="Server-side output schema validation is skipped when the tool returns an isError result.", + ), + "mcpserver:tool:duplicate-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-names", + behavior="Registering a tool with a name already in use is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; " + "warn_on_duplicate_tools defaults to True and warning is the only effect -- there is " + "no rejection mode." + ), + ), + ), + "mcpserver:tool:extra": Requirement( + source="sdk", + behavior=( + "Tool functions can access request metadata (request id, client params, session) through the " + "Context parameter." + ), + ), + "mcpserver:tool:handler-throws": Requirement( + source="sdk", + behavior=( + "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " + "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." 
+ ), + ), + "mcpserver:tool:input-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "Arguments that fail the tool's input validation produce a tool execution error (isError true " + "with the validation failure described in content) without invoking the function." + ), + ), + "mcpserver:tool:naming-validation": Requirement( + source="sdk", + behavior="Tool names that violate the spec's naming rules are rejected at registration time.", + deferred=( + "Not implemented in the SDK: MCPServer accepts any string as a tool name; there is no " + "spec-naming-rules check at registration time." + ), + ), + "mcpserver:tool:output-schema:model": Requirement( + source="sdk", + behavior=( + "A tool returning a typed model advertises a matching generated outputSchema and returns the " + "model's fields as structuredContent alongside a serialised text block." + ), + ), + "mcpserver:tool:output-schema:wrapped": Requirement( + source="sdk", + behavior=( + "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " + "structuredContent, with a matching generated outputSchema." + ), + ), + "mcpserver:tool:schema-variants": Requirement( + source="sdk", + behavior=( + "Tool input schemas generated from complex parameter types (unions, nested models, " + "constrained types) validate and coerce arguments before the function runs." + ), + ), + "mcpserver:tool:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name that was never registered returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec classifies unknown tools as a protocol error (its example uses -32602 Invalid " + "params); MCPServer reports a tool execution error (isError true) instead. The low-level " + "path follows the spec example (see tools:call:unknown-name)." 
+ ), + ), + ), + "mcpserver:tool:url-elicitation-error": Requirement( + source="sdk", + behavior=( + "A tool function that raises the URL-elicitation-required error surfaces to the caller as " + "error -32042 with the elicitation parameters intact." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # MCPServer: Context helpers (SDK) + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:context:logging": Requirement( + source="sdk", + behavior=( + "The Context logging helpers (debug/info/warning/error) send log message notifications at the " + "corresponding severity." + ), + ), + "mcpserver:context:progress": Requirement( + source="sdk", + behavior=( + "Context.report_progress sends a progress notification against the requesting client's progress token." + ), + ), + "mcpserver:context:elicit": Requirement( + source="sdk", + behavior=( + "Context.elicit sends a form elicitation built from a typed schema and returns a typed " + "accepted/declined/cancelled result." + ), + ), + "mcpserver:context:read-resource": Requirement( + source="sdk", + behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources + # ═══════════════════════════════════════════════════════════════════════════ + "resources:annotations": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#annotations", + behavior=( + "Resource annotations (audience, priority) supplied by the server round-trip to the client " + "in the list result." + ), + ), + "resources:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#capabilities", + behavior=( + "A server with resource handlers advertises the resources capability, including the subscribe " + "sub-flag when a subscribe handler is registered." 
+ ), + ), + "resources:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", + behavior=( + "When the resource set changes, a server that declared the resources listChanged capability " + "sends notifications/resources/list_changed and it reaches the client's handler." + ), + ), + "resources:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#listing-resources", + behavior=( + "resources/list returns the registered resources with uri, name, and the optional descriptive " + "fields supplied by the server." + ), + ), + "resources:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/list supports cursor pagination.", + ), + "resources:read:blob": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns binary contents base64-encoded in blob.", + ), + "resources:read:template-vars": Requirement( + source="sdk", + behavior="Variables extracted from a templated resource URI reach the resource function as typed arguments.", + ), + "resources:read:text": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns text contents carrying uri, mimeType, and the text.", + ), + "resources:read:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for an unknown URI returns JSON-RPC error -32002 (resource not found).", + ), + "resources:subscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty result.", + ), + "resources:subscribe:capability-required": Requirement( + source="sdk", + behavior=( + "resources/subscribe to a server that did not advertise the subscribe capability is rejected with an error." 
+ ), + ), + "resources:subscribe:updated": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/subscribe, changes to that resource send notifications/resources/updated.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state linking subscribe to " + "updated notifications; emitting updates is entirely handler code. The two halves are pinned " + "separately by resources:subscribe and resources:updated-notification." + ), + ), + "resources:templates:list": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#resource-templates", + behavior=( + "resources/templates/list returns the registered templates with their uriTemplate and descriptive fields." + ), + ), + "resources:templates:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/templates/list supports cursor pagination.", + ), + "resources:unsubscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "resources/unsubscribe delivers the URI to the server's unsubscribe handler and returns an empty result." + ), + ), + "resources:unsubscribe:stops-updates": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/unsubscribe the server stops sending updated notifications for that URI.", + deferred=( + "The SDK keeps no subscription state -- emitting updated notifications is entirely handler " + "code -- so there is no SDK behaviour to pin beyond the unsubscribe request reaching the " + "handler (covered by resources:unsubscribe)." + ), + ), + "resources:updated-notification": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "A resources/updated notification sent by the server reaches the client carrying the URI of " + "the changed resource." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:resource:duplicate-name": Requirement( + source="sdk", + behavior="Registering a resource or template with a duplicate identifier is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name). " + "Templates differ: a duplicate uri_template silently replaces the first with no warning." + ), + ), + ), + "mcpserver:resource:read-throws-surfaced": Requirement( + source="sdk", + behavior="A resource function that raises is surfaced to the caller as a JSON-RPC error response.", + ), + "mcpserver:resource:static": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " + "served by resources/read at that URI." + ), + ), + "mcpserver:resource:template": Requirement( + source="sdk", + behavior=( + "A function registered with a URI template is listed by resources/templates/list and matched " + "by resources/read, receiving the parameters extracted from the requested URI." + ), + ), + "mcpserver:resource:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for a URI matching no registered resource returns JSON-RPC error -32002.", + divergence=Divergence( + note=( + "The spec reserves -32002 for resource-not-found; MCPServer raises ResourceError, which " + "the low-level server converts to error code 0." 
+ ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts + # ═══════════════════════════════════════════════════════════════════════════ + "prompts:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#capabilities", + behavior="A server with a list_prompts handler advertises the prompts capability in its initialize result.", + ), + "prompts:get:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#audio-content", + behavior="Prompt messages may contain audio content with base64 data and a mimeType.", + ), + "prompts:get:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#embedded-resources", + behavior="Prompt messages may contain embedded resource content.", + ), + "prompts:get:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#image-content", + behavior="Prompt messages may contain image content.", + ), + "prompts:get:missing-required-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get omitting a required argument returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " + "which the low-level server converts to error code 0 with the exception text as the message." 
+ ), + ), + ), + "prompts:get:multi-message": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="A prompt can return multiple messages mixing user and assistant roles; order is preserved.", + ), + "prompts:get:no-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get with no arguments returns the prompt's messages.", + ), + "prompts:get:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for an unknown prompt name returns JSON-RPC error -32602 (Invalid params).", + ), + "prompts:get:with-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get delivers the supplied arguments to the prompt handler and returns its messages.", + ), + "prompts:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", + behavior=( + "When the prompt set changes, a server that declared the prompts listChanged capability sends " + "notifications/prompts/list_changed and it reaches the client's handler." 
+ ), + ), + "prompts:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", + behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", + ), + "prompts:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="prompts/list supports cursor pagination.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:prompt:args-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#implementation-considerations", + behavior="prompts/get arguments that fail the prompt's argument schema are rejected before the function runs.", + ), + "mcpserver:prompt:decorated": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.prompt() is listed with arguments derived from its signature " + "and rendered into prompt messages by prompts/get." + ), + ), + "mcpserver:prompt:duplicate-name": Requirement( + source="sdk", + behavior="Registering a duplicate prompt name is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." 
+ ), + ), + ), + "mcpserver:prompt:optional-args": Requirement( + source="sdk", + behavior="A prompt with optional arguments can be fetched without supplying them.", + ), + "mcpserver:prompt:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for a name that was never registered returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " + "ValueError, which the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Completion + # ═══════════════════════════════════════════════════════════════════════════ + "completion:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior="A server with a completion handler advertises the completions capability in its initialize result.", + ), + "completion:complete:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior=( + "A server with no completion handler does not advertise the completions capability and rejects " + "completion/complete with METHOD_NOT_FOUND." + ), + ), + "completion:context-arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", + behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", + ), + "completion:error:invalid-ref": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#error-handling", + behavior=( + "completion/complete with a ref naming an unknown prompt or non-matching resource URI returns " + "JSON-RPC error -32602 (Invalid params)." 
+ ), + ), + "completion:prompt-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", + ), + "completion:resource-template-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", + ), + "completion:result-shape": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#completion-results", + behavior="The completion result carries values (at most 100), an optional total, and an optional hasMore flag.", + ), + "mcpserver:completion:capability-auto": Requirement( + source="sdk", + behavior=( + "MCPServer advertises the completions capability when at least one completion source is " + "registered, and omits it otherwise." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Logging + # ═══════════════════════════════════════════════════════════════════════════ + "logging:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", + behavior=( + "A server that emits log message notifications declares the logging capability in its initialize result." + ), + divergence=Divergence( + note=( + "MCPServer registers no setLevel handler, so capability derivation leaves logging unset " + "even though the Context helpers send log message notifications." 
+ ), + ), + ), + "logging:message:all-levels": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", + ), + "logging:message:fields": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "A log message sent by a server handler is delivered to the client's logging callback with its " + "severity level, logger name, and data, in the order the server sent them." + ), + ), + "logging:message:filtered": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="After logging/setLevel, log messages below the configured level are not sent.", + divergence=Divergence( + note=( + "Neither MCPServer (which rejects logging/setLevel with method-not-found) nor the " + "low-level Server (which leaves the handler entirely to the author) implements any " + "filtering; messages are delivered at every severity regardless of the requested level." + ), + ), + ), + "logging:set-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", + ), + "logging:set-level:invalid-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#error-handling", + behavior="logging/setLevel with an invalid level value returns JSON-RPC error -32602 (Invalid params).", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Sampling (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "sampling:capability:declare": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A client that handles sampling requests advertises the sampling capability in its initialize request." 
+ ), + ), + "sampling:create:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A sampling/createMessage request from a server handler is answered by the client's sampling " + "callback, and the callback's result (role, content, model, stopReason) is returned to the handler." + ), + ), + "sampling:create:include-context": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior="The includeContext value supplied by the server reaches the client callback intact.", + ), + "sampling:context:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "The server does not use includeContext values thisServer or allServers unless the client " + "declared the sampling.context capability." + ), + divergence=Divergence( + note=( + "include_context is forwarded regardless of the client's declared sampling.context " + "capability; the server-side validator only checks tools/tool_choice." + ), + ), + ), + "sampling:create:model-preferences": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#model-preferences", + behavior=( + "The model preferences supplied by the server (hints and the cost, speed, and intelligence " + "priorities) reach the client callback intact." + ), + ), + "sampling:create:system-prompt": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior="The system prompt supplied by the server reaches the client callback intact.", + ), + "sampling:create:tools": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " + "with a toolUse stop reason returns to the requesting handler." 
+ ), + deferred=( + "Not implemented in the SDK: Client does not expose ClientSession's sampling_capabilities " + "parameter, so a client can never declare sampling.tools and the server-side validator " + "rejects every tool-enabled request before it is sent." + ), + ), + "sampling:create-message:audio-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#audio-content", + behavior="Sampling messages can carry audio content: base64 data with a mimeType.", + ), + "sampling:create-message:image-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#image-content", + behavior="Sampling messages can carry image content: base64 data with a mimeType.", + ), + "sampling:create-message:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A sampling request to a client that did not declare the sampling capability fails with an " + "error rather than hanging or being silently dropped; the spec names no error code for this case." + ), + ), + "sampling:error:user-rejected": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#error-handling", + behavior=( + "A sampling request the user rejects is answered with a JSON-RPC error (the spec's code for " + "this case is -1, 'User rejected sampling request'), surfaced to the requesting handler as an MCPError." + ), + ), + "sampling:message:content-cardinality": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling", + behavior="A sampling message's content may be a single block or an array of blocks.", + ), + "sampling:result:no-tools-single-content": Requirement( + source="sdk", + behavior=( + "When the request carries no tools, a sampling callback result whose content is an array is " + "rejected by the client." 
+ ), + divergence=Divergence( + note=( + "The client does not validate the callback result against the request shape; an array-content " + "result for a tool-free request is accepted client-side and surfaces as a raw " + "pydantic.ValidationError from the server's response parsing (send_request) instead." + ), + ), + ), + "sampling:result:with-tools-array-content": Requirement( + source="sdk", + behavior=( + "When the request includes tools, the client accepts a callback result whose content is an " + "array including tool_use blocks." + ), + deferred=( + "Not implemented in the SDK: requires declaring sampling.tools, which the high-level client " + "cannot do (see sampling:create:tools)." + ), + ), + "sampling:tool-result:no-mixed-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-result-messages", + behavior=( + "A user sampling message that carries tool_result content contains only tool_result blocks; " + "mixing tool_result with text, image, or audio content is rejected as invalid." + ), + ), + "sampling:tool-use:result-balance": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-use-and-result-balance", + behavior=( + "Every assistant tool_use block in a sampling request must be matched by a tool_result with " + "the same id in the following user message; an unmatched tool_use is rejected with Invalid params." + ), + ), + "sampling:tools:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " + "by the server before anything reaches the wire (the SDK surfaces this as an Invalid params error)." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Elicitation (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "elicitation:capability:empty-is-form": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior="A client advertising an empty elicitation capability accepts form-mode elicitation requests.", + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares explicit " + "form and url sub-capabilities, so an empty elicitation capability cannot be produced through " + "the public API." + ), + ), + "elicitation:capability:mode-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "The client answers elicitation requests for a mode it did not advertise with JSON-RPC error " + "-32602 (Invalid params)." + ), + deferred=( + "Not implemented in the SDK: a client cannot be configured form-only or url-only, so the " + "per-mode mismatch error cannot arise (see elicitation:url:not-supported)." + ), + ), + "elicitation:capability:server-respects-mode": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior=( + "The server refuses to send an elicitation request with a mode the connected client did not " + "declare in its capabilities." + ), + divergence=Divergence( + note=( + "The server does not check the client's declared elicitation modes before sending " + "elicitation/create; the spec's MUST NOT is not enforced." + ), + ), + ), + "elicitation:form:action:accept": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A form-mode elicitation answered with action 'accept' returns the user's content to the " + "requesting handler." 
+ ), + ), + "elicitation:form:action:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", + ), + "elicitation:form:action:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", + ), + "elicitation:form:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + behavior=( + "A form-mode elicitation delivers the message and requested schema to the client callback " + "exactly as the server sent them." + ), + ), + "elicitation:form:defaults": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Optional default values declared in a form-mode requested schema are pre-populated into the " + "form presented to the user." + ), + deferred=( + "Not implemented in the SDK: there is no form-rendering layer that could pre-populate " + "defaults; client callbacks receive the requested schema as-is." + ), + ), + "elicitation:form:mode-omitted-default": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#elicitation-requests", + behavior="An elicitation request with no mode field is treated as form mode by the client.", + ), + "elicitation:form:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "An elicitation request to a client that did not declare the elicitation capability is " + "answered with -32602 Invalid params." 
+ ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32602.", + ), + ), + "elicitation:form:schema:enum-variants": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Requested-schema enum fields (including titled and multi-select variants) reach the client " + "callback as sent." + ), + ), + "elicitation:form:schema:primitives": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior="Requested-schema fields may be string (with format), number or integer, or boolean.", + ), + "elicitation:form:schema:restricted-subset": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Form-mode requested schemas are flat objects with primitive-typed properties only; nested " + "structures and arrays of objects are not used." + ), + divergence=Divergence( + note=( + "Nothing restricts or validates the requested-schema shape on the sending side; a server " + "can send nested or non-primitive schemas and the SDK forwards them unchanged." + ), + ), + ), + "elicitation:form:response-validation": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-security", + behavior=( + "Accepted form-mode content is validated against the requested schema: the client validates " + "the response before sending and the server validates the content it receives." + ), + divergence=Divergence( + note="Accepted elicitation content passes through unvalidated on both sides.", + ), + ), + "elicitation:url:action:accept-no-content": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " + "response carries no content (accept means the user agreed to visit the URL, not that the " + "interaction completed)." 
+ ), + ), + "elicitation:url:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + behavior=( + "A url-mode elicitation delivers the elicitation id and URL to the client callback exactly as " + "the server sent them." + ), + ), + "elicitation:url:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with cancel returns the action with no content.", + ), + "elicitation:url:complete-notification": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "An elicitation/complete notification sent by the server after an out-of-band elicitation " + "finishes reaches the client carrying the elicitationId." + ), + ), + "elicitation:url:complete-unknown-ignored": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "The client ignores an elicitation/complete notification referencing an unknown or " + "already-completed elicitationId without error." + ), + ), + "elicitation:url:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with decline returns the action with no content.", + ), + "elicitation:url:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "A URL-mode elicitation to a client that declared only form-mode support is rejected with an " + "Invalid params error." + ), + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares both the " + "form and url sub-capabilities, so a form-only client cannot be constructed." 
+ ), + ), + "elicitation:url:required-error": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A handler that cannot proceed without a URL elicitation rejects the request with error " + "-32042, carrying the pending elicitations in the error data." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Roots (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "roots:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + ), + "roots:list-changed:client-emits": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior=( + "A client that declared roots.listChanged sends notifications/roots/list_changed when its set " + "of roots changes." + ), + deferred=( + "Not implemented in the SDK: the client does not own the root set (it calls back to the host " + "via list_roots_callback), so there is no mutation it could observe to auto-emit on; the SDK " + "provides send_roots_list_changed() for the host to call when its roots change, and that " + "emission path is covered by roots:list-changed." + ), + ), + "roots:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior=( + "A roots/list request from a server handler is answered by the client's roots callback, and " + "the returned roots (uri, name) reach the handler." 
+ ), + ), + "roots:list:client-error": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior="A roots callback that answers with an error surfaces to the requesting handler as an MCPError.", + ), + "roots:list:empty": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior="An empty roots list is a valid response and reaches the handler as such.", + ), + "roots:list:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior=( + "A roots/list request to a client that did not declare the roots capability is answered with " + "-32601 Method not found." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32601.", + ), + ), + "roots:uri:file-scheme": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root", + behavior="Every root returned by the client identifies itself with a file:// URI.", + deferred=( + "Schema-level validation: the FileUrl type on Root.uri rejects any non-file:// scheme at " + "construction and at parse, so a non-conforming root cannot reach the wire from either side; " + "type-level coverage belongs in tests/test_types.py rather than this interaction suite." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # list_changed & dynamic registration + # ═══════════════════════════════════════════════════════════════════════════ + "client:list-changed:auto-refresh": Requirement( + source="sdk", + behavior=( + "A client configured to react to list_changed notifications automatically re-fetches the " + "corresponding list and delivers the fresh result to its callback." + ), + deferred=( + "Not implemented in the SDK: the client has no list-changed auto-refresh mechanism; " + "notifications are only delivered to the message handler." 
+ ), + ), + "client:list-changed:capability-gated": Requirement( + source="sdk", + behavior=( + "The client does not activate list-changed handling for a kind the server did not advertise " + "with listChanged true." + ), + deferred="Not implemented in the SDK: no client-side list-changed handling exists to gate.", + ), + "client:list-changed:signal-only": Requirement( + source="sdk", + behavior="A client configured for signal-only list-changed handling is notified without auto-refreshing.", + deferred="Not implemented in the SDK: no client-side list-changed handling exists.", + ), + "mcpserver:list-changed:debounce": Requirement( + source="sdk", + behavior=( + "Bursts of registration changes on MCPServer are debounced into one list_changed notification per kind." + ), + deferred=( + "Not implemented in the SDK: MCPServer does not send list_changed notifications on " + "registration changes at all (see mcpserver:register:post-connect), so there is nothing to " + "debounce." + ), + ), + "mcpserver:register:post-connect": Requirement( + source="sdk", + behavior=( + "A tool, resource, or prompt registered or removed after the client connected appears in (or " + "disappears from) the corresponding list results, and the change is announced with a " + "list_changed notification." + ), + divergence=Divergence( + note=( + "MCPServer never sends list_changed notifications on registration changes, so a connected " + "client cannot learn that the set changed without polling." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Pagination + # ═══════════════════════════════════════════════════════════════════════════ + "pagination:exhaustion": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "Following nextCursor until it is absent yields every page exactly once; a result without " + "nextCursor ends the sequence." 
+ ), + ), + "pagination:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#error-handling", + behavior="A list request with an invalid cursor returns JSON-RPC error -32602 (Invalid params).", + ), + "pagination:client:cursor-handling": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#implementation-guidelines", + behavior=( + "The client treats cursors as opaque tokens — it does not parse, modify, or persist them — " + "and does not assume a fixed page size." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tasks (experimental) + # ═══════════════════════════════════════════════════════════════════════════ + "tasks:auth:context-isolation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-isolation-and-access-control", + behavior=( + "When an authorization context is available, task operations are scoped to the context that " + "created the task: other contexts cannot get it, retrieve its result, cancel it, or see it in " + "tasks/list." + ), + transports=("streamable-http",), + deferred=_TASKS_DEFERRAL, + ), + "tasks:bidirectional": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#definitions", + behavior="Task APIs are bidirectional: the server may create, get, list, and cancel tasks on the client.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:no-handler-abort": Requirement( + source="sdk", + behavior=( + "tasks/cancel marks the task cancelled without aborting the originating request handler " + "(the spec says receivers SHOULD attempt to stop execution)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:remains-cancelled": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior=( + "After tasks/cancel, the task remains cancelled even if the underlying handler subsequently " + "completes or fails." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:terminal-rejected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a task already in a terminal state returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a working task transitions it to cancelled and returns the updated task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:ttl-honored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#ttl-and-resource-management", + behavior=( + "tasks/get responses include the actual ttl applied by the receiver (or null for unlimited); " + "the create-task result carries the same value." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:via-tool-call": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#creating-tasks", + behavior="A task-augmented tools/call returns a create-task result instead of the tool result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:get": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#getting-tasks", + behavior="tasks/get returns the task's current status, ttl, timestamps, and status message.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:initial-working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-lifecycle", + behavior="A newly created task has status 'working'.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:input-required": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "While a task awaits a side-channel client response its status is input_required; once the " + "response arrives the task leaves input_required (typically returning to working)." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/list with an invalid cursor returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#listing-tasks", + behavior="tasks/list returns created tasks and supports cursor pagination.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:no-capability:ignore-task-param": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-support-and-handling", + behavior=( + "A receiver that did not declare task capability for a request type processes the request " + "normally and returns the ordinary result, ignoring the task augmentation." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:progress:after-create": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-progress-notifications", + behavior=( + "After the create-task result, progress notifications keyed to the original progress token " + "continue to reach the caller until the task is terminal." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:request-cancel:no-task-cancel": Requirement( + source="sdk", + behavior="A cancellation notification for the originating request does not auto-cancel the created task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:failed": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-execution-errors", + behavior="tasks/result for a failed task returns the failure result (isError true).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:related-task-meta": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="The tasks/result response carries related-task _meta naming the requested task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:terminal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#result-retrieval", + behavior="tasks/result for a completed task returns the stored result of the original request type.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drain-fifo": Requirement( + source="sdk", + behavior="tasks/result drains queued related-task messages in FIFO order before returning the final result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drop-on-cancel": Requirement( + source="sdk", + behavior="When a task is cancelled before tasks/result, queued related-task messages are dropped.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:elicitation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "An elicitation issued mid-task is delivered through the tasks/result side-channel, and the " + "client's response routes back to the handler." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:queue": Requirement( + source="sdk", + behavior=( + "Server-to-client requests with related-task metadata sent while no tasks/result is open are queued." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:sampling": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "A sampling request issued mid-task is delivered through the tasks/result side-channel, and " + "the client's response routes back to the task." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:stream": Requirement( + source="sdk", + behavior=( + "Calling tasks/result while the task is working streams related-task messages as they are " + "produced, then returns the result." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:status-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-notification", + behavior="Task status notifications deliver status updates carrying the full task fields.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:forbidden-with-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=( + "A task-augmented tools/call on a tool that does not support tasks returns Method not found (-32601)." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:required-no-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=("A plain tools/call on a tool that requires task augmentation returns Method not found (-32601)."), + deferred=_TASKS_DEFERRAL, + ), + "tasks:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/get, tasks/result, and tasks/cancel for an unknown task id return Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Transports (in-suite coverage) + # ═══════════════════════════════════════════════════════════════════════════ + "transport:streamable-http:stateful": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip (initialize, tool calls, tool errors) works through the " + "streamable HTTP framing in its default stateful SSE-response mode." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:json-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + transports=("streamable-http",), + ), + "transport:streamable-http:stateless": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip works in stateless mode, where every request is served by a " + "fresh transport with no session id." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:notifications": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "Notifications emitted during a request are delivered on that request's SSE stream and reach " + "the client's callbacks, in order, before the response." 
+ ), + transports=("streamable-http",), + ), + "transport:streamable-http:stateless-restrictions": Requirement( + source="sdk", + behavior=( + "A handler that attempts a server-initiated request in stateless mode fails with an error " + "result, because there is no session to call back through." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:unrelated-messages": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-to-client message that is not related to an in-flight request is routed to the " + "standalone GET stream and delivered to the client listening on it, not to any request's " + "own stream." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:resumability": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + transports=("streamable-http",), + ), + "transport:streamable-http:origin-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior="Requests with an invalid Origin header are rejected with 403 before reaching the session.", + transports=("streamable-http",), + ), + "transport:sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A client connected over the legacy HTTP+SSE transport completes the handshake and round-trips " + "requests, with server messages delivered on the SSE stream." 
+ ), + transports=("sse",), + ), + "transport:sse:endpoint-event": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "Opening the SSE stream delivers an `endpoint` event naming the message-POST URL and a fresh " + "session identifier; the server registers the session before the event is sent and releases it " + "when the stream disconnects." + ), + transports=("sse",), + ), + "transport:sse:post:session-routing": Requirement( + source="sdk", + behavior=( + "A POST to the SSE message endpoint that names no session id, a malformed session id, or an " + "unknown session id is rejected (400/400/404) instead of being forwarded." + ), + transports=("sse",), + ), + "transport:stdio": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "A Client connected to a real SDK Server over stdio initializes, calls a tool with arguments, " + "and receives notifications and results over the child process's stdin/stdout." + ), + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: session lifecycle + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:session:cors-expose": Requirement( + source="sdk", + behavior="CORS configuration exposes the Mcp-Session-Id header so browser clients can read it.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: CORS configuration is left to the hosting ASGI application.", + ), + "hosting:session:create": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "An initialize POST without a session id creates a session and returns Mcp-Session-Id in the " + "response headers." 
+ ), + transports=("streamable-http",), + ), + "hosting:session:delete": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="DELETE with a valid Mcp-Session-Id terminates the session and removes its transport.", + transports=("streamable-http",), + ), + "hosting:session:id-charset": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Generated Mcp-Session-Id values contain only visible ASCII characters.", + transports=("streamable-http",), + ), + "hosting:session:isolation": Requirement( + source="sdk", + behavior="Each session gets its own server instance; closing one session does not affect others.", + transports=("streamable-http",), + ), + "hosting:session:missing-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A non-initialize POST without Mcp-Session-Id in stateful mode returns 400.", + transports=("streamable-http",), + ), + "hosting:session:post-termination-404": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "After a session is terminated, any further request carrying that session ID is answered with " + "404 Not Found." + ), + transports=("streamable-http",), + ), + "hosting:session:reinitialize": Requirement( + source="sdk", + behavior="A second initialize on an already-initialized session transport is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport forwards a second initialize carrying the existing session ID to the running " + "server, which answers it as a fresh handshake; nothing rejects re-initialization." 
+ ), + ), + ), + "hosting:session:reuse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST carrying a valid Mcp-Session-Id routes to that session's transport with state preserved.", + transports=("streamable-http",), + ), + "hosting:session:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST, GET, or DELETE with an unknown Mcp-Session-Id returns 404.", + transports=("streamable-http",), + ), + "hosting:stateless:concurrent-clients": Requirement( + source="sdk", + behavior="Multiple independent clients can connect to a stateless server concurrently.", + transports=("streamable-http",), + ), + "hosting:stateless:no-reuse": Requirement( + source="sdk", + behavior="A stateless per-request transport cannot be reused for a second request.", + transports=("streamable-http",), + ), + "hosting:stateless:no-session-id": Requirement( + source="sdk", + behavior="In stateless mode no Mcp-Session-Id is emitted and no session validation is performed.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: auth + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:auth:as-router": Requirement( + source="sdk", + behavior=( + "The authorization-server routes expose the authorize, token, and registration endpoints " + "(and revocation when supported)." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." 
+ ), + ), + "hosting:auth:aud-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#access-token-usage", + behavior="The resource server validates that the token audience matches its resource identifier.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "hosting:auth:authinfo-propagates": Requirement( + source="sdk", + behavior="A valid token's auth info is exposed to request handlers.", + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:expired-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="An expired token returns 401 invalid_token.", + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:invalid-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="A malformed bearer token or token-verification failure returns 401 with WWW-Authenticate.", + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:metadata-endpoints": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The MCP server publishes protected-resource metadata at its well-known endpoint, and the " + "authorization server (which the SDK can also host) publishes authorization-server metadata " + "at its own." 
+ ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:missing-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "A request without an Authorization header is rejected with 401; the WWW-Authenticate header " + "carries resource_metadata (one of the spec's two permitted discovery mechanisms)." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:prm:authorization-servers-field": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The protected-resource metadata document includes an authorization_servers array with at least one entry." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "hosting:auth:scope-403": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#runtime-insufficient-scope-errors", + behavior=( + "A token lacking a required scope returns 403 with WWW-Authenticate carrying " + "insufficient_scope, the required scope, and resource_metadata." 
+ ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: resumability + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:resume:bad-event-id": Requirement( + source="sdk", + behavior="A Last-Event-ID that cannot be mapped to a stream is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The replay path returns an empty SSE stream rather than rejecting an unknown " + "Last-Event-ID; the client cannot tell an unknown ID apart from a stream with no missed " + "events." + ), + ), + ), + "hosting:resume:buffered-replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Notifications emitted while no client is connected are replayed in order on reconnect.", + transports=("streamable-http",), + ), + "hosting:resume:close-stream": Requirement( + source="sdk", + behavior="Handlers can close an SSE stream cleanly when an event store is configured.", + transports=("streamable-http",), + ), + "hosting:resume:event-ids": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="With an event store configured, every SSE event carries an id field.", + transports=("streamable-http",), + ), + "hosting:resume:priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A server-initiated SSE stream begins with a priming event carrying an event ID and an empty " + "data field; a server that closes the connection before terminating the stream sends an SSE " + "retry field first." 
+ ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The retry hint is attached to the priming event itself rather than sent as a separate " + "event before the connection closes, and a priming event is only sent when an event store " + "is configured and the negotiated protocol version is at least 2025-11-25." + ), + ), + ), + "hosting:resume:replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="GET with Last-Event-ID replays stored events for that stream after the given id.", + transports=("streamable-http",), + ), + "hosting:resume:stream-scoped": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Replay via Last-Event-ID returns only messages from the stream that event id belongs to.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: HTTP semantics + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:http:accept-406": Requirement( + source="sdk", + behavior="A request whose Accept header does not allow the response representation returns 406.", + transports=("streamable-http",), + ), + "hosting:http:batch": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body is a single JSON-RPC message; batched arrays are rejected for protocol revisions " + "that forbid them." + ), + transports=("streamable-http",), + ), + "hosting:http:content-type-415": Requirement( + source="sdk", + behavior="A POST with a Content-Type other than application/json returns 415.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport-security middleware rejects a non-JSON Content-Type with 400 'Invalid " + "Content-Type header' before the request reaches the transport, so the transport's own 415 " + "path is unreachable through any public entry point." 
+ ), + ), + ), + "hosting:http:disconnect-not-cancel": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A client connection drop during an in-flight request does not cancel the server-side " + "handler; the request continues and its result remains retrievable." + ), + transports=("streamable-http",), + ), + "hosting:http:dns-rebinding": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior=( + "The Origin header is validated on every incoming connection; a request with an invalid " + "Origin is rejected with 403 Forbidden." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The spec's Origin validation is an unconditional MUST; the SDK enables it only when the " + "host is a localhost address or explicit TransportSecuritySettings are passed (with no " + "settings, no Origin validation runs), and additionally validates the Host header " + "(returning 421 on mismatch), which the spec does not require." + ), + ), + ), + "hosting:http:json-response-mode": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="With JSON response mode enabled, POST returns application/json instead of SSE.", + transports=("streamable-http",), + ), + "hosting:http:method-405": Requirement( + source="sdk", + behavior="An unsupported HTTP method on the MCP endpoint returns 405.", + transports=("streamable-http",), + ), + "hosting:http:no-broadcast": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#multiple-connections", + behavior=( + "When multiple SSE streams are open for a session, each server-originated message is sent on " + "exactly one stream, never duplicated." 
+ ), + transports=("streamable-http",), + ), + "hosting:http:notifications-202": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A POST containing only notifications or responses returns 202 with no body.", + transports=("streamable-http",), + ), + "hosting:http:onerror": Requirement( + source="sdk", + behavior="Transport-level rejections are reported through an error callback on the server transport.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: the server transport has no error callback; rejections are logged.", + ), + "hosting:http:parse-error-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body that is not valid JSON or not a valid JSON-RPC message is rejected with HTTP 400; " + "the body may carry a JSON-RPC error response (the SDK sends a Parse error body)." + ), + transports=("streamable-http",), + ), + "hosting:http:protocol-version-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior="An invalid or unsupported MCP-Protocol-Version header returns 400 Bad Request.", + transports=("streamable-http",), + ), + "hosting:http:protocol-version-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "When no MCP-Protocol-Version header is received and the version cannot be determined another " + "way, the server assumes protocol version 2025-03-26." + ), + transports=("streamable-http",), + ), + "hosting:http:response-same-connection": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A response is delivered on the SSE stream opened by the POST that carried its request (or " + "that stream's resumed continuation), not on an unrelated stream." 
+ ), + transports=("streamable-http",), + ), + "hosting:http:second-sse-rejected": Requirement( + source="sdk", + behavior="A second concurrent standalone GET SSE stream on the same session is rejected.", + transports=("streamable-http",), + ), + "hosting:http:sse-close-after-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="The server terminates a POST-initiated SSE stream after writing the JSON-RPC response.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="GET opens a standalone SSE stream that receives server-initiated messages.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse-no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior=( + "The standalone GET SSE stream carries server requests and notifications but never a JSON-RPC " + "response, except when resuming a prior request stream." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client transport: streamable HTTP + # ═══════════════════════════════════════════════════════════════════════════ + "client-transport:http:404-surfaces": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "A 404 in response to a request carrying a session ID makes the client start a new session " + "with a fresh InitializeRequest and no session ID attached." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client surfaces the 404 as an error to the caller instead of re-initializing a new " + "session; the spec's MUST is not satisfied." 
+ ), + ), + ), + "client-transport:http:accept-header-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="The client GET to the MCP endpoint includes an Accept header listing text/event-stream.", + transports=("streamable-http",), + ), + "client-transport:http:accept-header-post": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "Every client POST to the MCP endpoint includes an Accept header listing both application/json " + "and text/event-stream." + ), + transports=("streamable-http",), + ), + "client-transport:http:concurrent-streams": Requirement( + source="sdk", + behavior="Multiple concurrent POST-initiated SSE streams each deliver their response to the right caller.", + transports=("streamable-http",), + ), + "client-transport:http:custom-client": Requirement( + source="sdk", + behavior=( + "A caller-supplied HTTP client (and its event hooks and headers) is used for all MCP traffic, " + "including auth flows." 
+ ), + transports=("streamable-http",), + ), + "client-transport:http:custom-headers": Requirement( + source="sdk", + behavior="Caller-supplied headers are sent on every POST, GET, and DELETE to the MCP endpoint.", + transports=("streamable-http",), + ), + "client-transport:http:json-response-parsed": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A Content-Type application/json response is parsed as a single JSON-RPC message.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-close": Requirement( + source="sdk", + behavior="After the transport is closed, no further reconnection attempts are scheduled.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-response": Requirement( + source="sdk", + behavior="A POST-initiated stream that already delivered its response is not reconnected when it closes.", + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-header": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "After initialization, the client sends the negotiated MCP-Protocol-Version header on every " + "subsequent HTTP request." + ), + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-stored": Requirement( + source="sdk", + behavior=( + "The client transport stores the negotiated protocol version and sends it on every subsequent request." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A standalone GET SSE stream that errors is reconnected with the Last-Event-ID of the last received event." 
+ ), + transports=("streamable-http",), + deferred=( + "Not yet covered here: the standalone GET stream emits no priming event or retry hint, so " + "the client's reconnection path always sleeps the hard-coded 1 s default; a deterministic " + "in-process test would inject real-time delay or require an SDK change. The POST-stream " + "reconnection path is covered by client-transport:http:reconnect-post-priming." + ), + ), + "client-transport:http:reconnect-post-priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST-initiated SSE stream that errors before delivering its response is reconnected only " + "if a priming event (an event carrying an ID) was received on it." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-retry-value": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="Reconnection delay honours the server-provided SSE retry value when one was sent.", + transports=("streamable-http",), + ), + "client-transport:http:resume-stream-api": Requirement( + source="sdk", + behavior=( + "The client can capture a resumption token, reconnect with the same session id, and receive " + "the notifications it missed." + ), + transports=("streamable-http",), + ), + "client-transport:http:session-stored": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "The Mcp-Session-Id returned by initialize is stored by the client transport and sent on " + "every subsequent request." 
+ ), + transports=("streamable-http",), + ), + "client-transport:http:sse-405-tolerated": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="Opening the standalone GET SSE stream tolerates a 405 response without failing the connection.", + transports=("streamable-http",), + ), + "client-transport:http:terminate-405-ok": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Session termination succeeds without error if the server answers 405 (termination unsupported).", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client auth + # ═══════════════════════════════════════════════════════════════════════════ + "client-auth:401-after-auth-throws": Requirement( + source="sdk", + behavior=( + "If the server still returns 401 after a successful authorization, the client fails instead of looping." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:401-triggers-flow": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior="A 401 on a request triggers the OAuth authorization flow once.", + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." + ), + ), + "client-auth:403-scope-upgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#step-up-authorization-flow", + behavior=( + "A 403 with WWW-Authenticate triggers a scope-upgrade authorization attempt; repeated 403s do not loop." 
+ ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:as-metadata-discovery:priority-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client discovers authorization-server metadata by trying, in order, the OAuth " + "path-inserted, OIDC path-inserted, and OIDC path-appended well-known URLs (with the " + "root-path forms when the issuer URL has no path)." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." + ), + ), + "client-auth:bearer-header:every-request": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-requirements", + behavior=( + "Once authorized, the client sends the bearer token in the Authorization header on every HTTP " + "request to the MCP server, never in the query string." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." + ), + ), + "client-auth:cimd": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#client-id-metadata-documents", + behavior="The client can use a client-ID metadata document URL as its OAuth client_id instead of registration.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: client-ID metadata documents are not supported.", + ), + "client-auth:client-credentials": Requirement( + source="sdk", + behavior=( + "A client-credentials provider obtains a token without user interaction and the resulting " + "bearer token authorizes subsequent requests." 
+ ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "client-auth:dcr": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#dynamic-client-registration", + behavior=( + "The client performs dynamic client registration against the authorization server when no " + "client_id is preconfigured." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." + ), + ), + "client-auth:invalid-client-clears-all": Requirement( + source="sdk", + behavior=( + "An invalid-client or unauthorized-client error during authorization invalidates all stored credentials." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:invalid-grant-clears-tokens": Requirement( + source="sdk", + behavior="An invalid-grant error during authorization invalidates only the stored tokens.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:pkce:refuse-if-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The client refuses to proceed when the authorization server's metadata does not include " + "code_challenge_methods_supported, since PKCE support cannot be verified." 
+ ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:pkce:s256": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The authorization request includes a PKCE S256 code challenge and the token request includes " + "the matching verifier." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." + ), + ), + "client-auth:pre-registration": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#preregistration", + behavior=( + "A client with statically preconfigured credentials skips dynamic registration and uses them directly." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." + ), + ), + "client-auth:private-key-jwt": Requirement( + source="sdk", + behavior="The client can authenticate the client-credentials grant with a signed JWT assertion.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: JWT-assertion client authentication is not supported.", + ), + "client-auth:prm-discovery:fallback-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "The client uses resource_metadata from WWW-Authenticate when present, then falls back to the " + "well-known protected-resource locations in the documented order." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." 
+ ), + ), + "client-auth:prm-resource-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The client refuses to proceed when the protected-resource metadata's resource field does not " + "match the server URL it is connecting to." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:resource-parameter": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#resource-parameter-implementation", + behavior=( + "The client includes the canonical server URI as the resource parameter in both the " + "authorization request and the token request." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." + ), + ), + "client-auth:scope-selection:priority": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#scope-selection-strategy", + behavior=( + "The client selects the requested scope from WWW-Authenticate when present, then from the " + "protected-resource metadata, and otherwise omits scope." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:state:verify": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "A state parameter is included in the authorization URL, and authorization results with a " + "missing or mismatched state are discarded." 
+ ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:token-endpoint-auth-method": Requirement( + source="sdk", + behavior="The client authenticates to the token endpoint using the auth method established at registration.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:token-provenance": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior=( + "The client sends the MCP server only tokens issued by that server's authorization server, " + "never tokens obtained elsewhere." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # stdio transport + # ═══════════════════════════════════════════════════════════════════════════ + "transport:stdio:clean-shutdown": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#shutdown", + behavior="Closing the client transport closes the child process's stdin and the server exits cleanly.", + transports=("stdio",), + ), + "transport:stdio:stream-purity": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "Nothing that is not a valid MCP message is written to the server's stdout, and nothing that " + "is not a valid MCP message is written to its stdin." + ), + transports=("stdio",), + divergence=Divergence( + note=( + "stdio_server's own writes satisfy this, but it does not redirect or guard sys.stdout: " + "handler code that calls print() writes directly to the protocol stream and corrupts the " + "framing. The spec MUST is satisfied only as long as application code behaves." 
+ ), + ), + ), + "transport:stdio:no-embedded-newlines": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior="Serialized JSON-RPC messages on stdio contain no embedded newlines; one message per line.", + transports=("stdio",), + ), + "transport:stdio:shutdown-escalation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#stdio", + behavior=( + "If the server process does not exit after stdin is closed, the client transport terminates " + "it (and kills it if still alive) after a grace period." + ), + transports=("stdio",), + deferred=( + "Not yet covered here: a server that ignores stdin close takes the full " + "PROCESS_TERMINATION_TIMEOUT (2.0 s) grace period plus up to a further 2.0 s for " + "SIGTERM/SIGKILL escalation; a robust test of that path is real-time-bound and the constant " + "is module-level (no public override). Covered by tests/client/test_stdio.py." + ), + ), + "transport:stdio:stderr-passthrough": Requirement( + source="sdk", + behavior="Server stderr is available to the client and is not consumed by the transport.", + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Composite end-to-end flows + # ═══════════════════════════════════════════════════════════════════════════ + "flow:compat:dual-transport-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A single server instance can serve streamable HTTP and the legacy SSE transport " + "concurrently; clients on either transport can call the same tools." + ), + transports=("streamable-http", "sse"), + ), + "flow:compat:streamable-then-sse-fallback": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "When a streamable HTTP initialize fails with 400, 404, or 405, falling back to the legacy " + "SSE client transport against the same server connects successfully." 
+ ), + transports=("streamable-http", "sse"), + divergence=Divergence( + note=( + "The SDK provides no automatic streamable-HTTP-to-SSE client fallback; the spec's " + "client-side SHOULD is left to the application to compose from streamable_http_client " + "and sse_client. Both halves are independently proven by the matrix." + ), + ), + deferred=( + "A demonstration test would only re-prove what the matrix already covers (an SSE-only " + "server is reachable via sse_client; an unmounted route returns 404), with the application " + "doing the fallback in between rather than the SDK." + ), + ), + "flow:elicitation:multi-step-form": Requirement( + source="sdk", + behavior=( + "A single tool handler issues sequential elicitations; an accept on one step feeds the next, " + "and a decline or cancel at any step short-circuits to a final result." + ), + ), + "flow:elicitation:url-at-session-init": Requirement( + source="sdk", + behavior=( + "The server can issue a URL-mode elicitation over the standalone GET stream immediately after " + "session initialization, before any client request." + ), + transports=("streamable-http",), + deferred=( + "No public per-session post-initialization hook exists on either server flavour " + "(Server.lifespan runs at server startup, not per session; ServerSession handles the " + "initialized notification internally with no callback). Driving 'before any client " + "request' deterministically would also require knowing the standalone GET stream is " + "established, which has no synchronization signal." + ), + ), + "flow:elicitation:url-required-then-retry": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A tool call rejected with the URL-elicitation-required error can be retried successfully " + "after the client completes the URL flow and the server announces completion." 
def requirement(requirement_id: str) -> Callable[[_TestFn], _TestFn]:
    """Build a decorator that links a test to one entry of :data:`REQUIREMENTS`.

    The decorator appends the test's dotted name to the coverage registry (the link that
    `test_coverage.py` checks) and attaches the `requirement` pytest marker to the test.

    Raises:
        KeyError: if `requirement_id` is not a key of REQUIREMENTS. The check runs when the
            decorator is built, i.e. while the test module is being imported, so a typo shows
            up as a collection error on the offending test instead of a missing-coverage
            report later.
    """
    if requirement_id not in REQUIREMENTS:
        raise KeyError(f"Unknown requirement id {requirement_id!r}: add it to REQUIREMENTS in {__name__}")

    def link(test_fn: _TestFn) -> _TestFn:
        # Fetch the registry entry first, then record the test under its module-qualified name.
        recorded_tests = covered_by(requirement_id)
        recorded_tests.append(f"{test_fn.__module__}.{test_fn.__qualname__}")
        marker = pytest.mark.requirement(requirement_id)
        return marker(test_fn)

    return link


# Maps a requirement id to the dotted names of the tests decorated with it.
_COVERAGE: dict[str, list[str]] = {}


def covered_by(requirement_id: str) -> list[str]:
    """Return the live list of test names recorded for `requirement_id`.

    The list is the registry's own storage, not a copy, so callers may append to it. An id that
    has not been seen yet gets a fresh empty list stored under it.
    """
    if requirement_id not in _COVERAGE:
        _COVERAGE[requirement_id] = []
    return _COVERAGE[requirement_id]
+ """ + transport_name = request.param + assert isinstance(transport_name, str) + return _FACTORIES[transport_name] diff --git a/tests/interaction/lowlevel/__init__.py b/tests/interaction/lowlevel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py new file mode 100644 index 0000000000..f39b2014cf --- /dev/null +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -0,0 +1,232 @@ +"""Cancellation interactions against the low-level Server, driven through the public Client API. + +There is no client-side cancellation API: cancelling means sending a CancelledNotification +carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so +these tests capture the id from inside the blocked handler before cancelling. The handler blocks +on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + EmptyResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + PingRequest, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:cancel:in-flight") +@requirement("protocol:cancel:handler-abort-propagates") +async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: + """Cancelling an in-flight request interrupts its handler and fails the pending call. 
@requirement("protocol:cancel:server-survives")
async def test_session_serves_requests_after_cancellation(connect: Connect) -> None:
    """The session stays usable after one of its requests is cancelled mid-flight.

    A blocking tool call is cancelled by request id, and a second tool call on the same session
    must then succeed.
    """
    handler_entered = anyio.Event()
    blocked_request_ids: list[types.RequestId] = []

    async def advertise_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        blocking_tool = types.Tool(name="block", input_schema={"type": "object"})
        echo_tool = types.Tool(name="echo", input_schema={"type": "object"})
        return types.ListToolsResult(tools=[blocking_tool, echo_tool])

    async def handle_call(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
        if params.name != "echo":
            assert ctx.request_id is not None
            blocked_request_ids.append(ctx.request_id)
            handler_entered.set()
            await anyio.Event().wait()  # a fresh event nobody sets: only cancellation ends this wait
            raise NotImplementedError  # unreachable
        return CallToolResult(content=[TextContent(text="still alive")])

    server = Server("blocker", on_list_tools=advertise_tools, on_call_tool=handle_call)

    async with connect(server) as client:
        with anyio.fail_after(5):
            async with anyio.create_task_group() as task_group:

                async def expect_blocked_call_to_fail() -> None:
                    with pytest.raises(MCPError):
                        await client.call_tool("block", {})

                task_group.start_soon(expect_blocked_call_to_fail)
                # The handler publishes its request id before setting the event, so index 0 is safe here.
                await handler_entered.wait()
                cancel_params = types.CancelledNotificationParams(request_id=blocked_request_ids[0])
                await client.session.send_notification(types.CancelledNotification(params=cancel_params))

            result = await client.call_tool("echo", {})

    assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")]))
@requirement("protocol:cancel:late-response-ignored")
async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None:
    """A response whose id matches nothing in flight is handed to the message handler as a RuntimeError.

    Per the spec, a sender SHOULD ignore a response that arrives after it issued a cancellation.
    On the client that is the same code path as any response with an unknown id, and the
    unknown-id form can be tested deterministically without the cancellation API the SDK does not
    yet provide. See the divergence note on the requirement.

    A real Server cannot be made to answer with a fabricated id, so this test scripts the server's
    side of the wire by hand; reserve that pattern for behaviour no real server can produce. The
    other tests in this file run over the transport matrix. This one is in-memory only because the
    scripted peer is built on the in-memory stream pair, not because the behaviour is
    transport-specific.
    """
    surfaced: list[IncomingMessage] = []

    def wire_response(request_id: types.RequestId, result: types.Result) -> SessionMessage:
        # Serialized exactly as a real server serializes results onto the wire.
        payload = result.model_dump(by_alias=True, mode="json", exclude_none=True)
        return SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=payload))

    async def message_handler(message: IncomingMessage) -> None:
        surfaced.append(message)

    async with create_client_server_memory_streams() as (client_streams, server_streams):
        client_read, client_write = client_streams
        server_read, server_write = server_streams

        async def scripted_server() -> None:
            # Step 1: answer the initialize request so the client handshake can complete.
            initialize_request = await server_read.receive()
            assert isinstance(initialize_request, SessionMessage)
            assert isinstance(initialize_request.message, JSONRPCRequest)
            assert initialize_request.message.method == "initialize"
            handshake = InitializeResult(
                protocol_version="2025-11-25",
                capabilities=ServerCapabilities(),
                server_info=Implementation(name="scripted", version="0.0.1"),
            )
            await server_write.send(wire_response(initialize_request.message.id, handshake))

            # Step 2: consume the initialized notification; it needs no reply.
            initialized_note = await server_read.receive()
            assert isinstance(initialized_note, SessionMessage)
            assert isinstance(initialized_note.message, JSONRPCNotification)
            assert initialized_note.message.method == "notifications/initialized"

            # Step 3: answer the ping twice, first under a fabricated id that matches nothing in
            # flight, then under the real id.
            ping_request = await server_read.receive()
            assert isinstance(ping_request, SessionMessage)
            assert isinstance(ping_request.message, JSONRPCRequest)
            assert ping_request.message.method == "ping"
            for reply_id in (9999, ping_request.message.id):
                await server_write.send(wire_response(reply_id, EmptyResult()))

        async with anyio.create_task_group() as task_group:
            task_group.start_soon(scripted_server)
            async with ClientSession(client_read, client_write, message_handler=message_handler) as session:
                with anyio.fail_after(5):
                    await session.initialize()
                    pong = await session.send_request(PingRequest(), EmptyResult)

    assert pong == snapshot(EmptyResult())
    assert len(surfaced) == 1
    stray = surfaced[0]
    assert isinstance(stray, RuntimeError)
    # The full message embeds the response object's repr; only the prefix is stable.
    assert str(stray).startswith("Received response with an unknown request ID:")
@requirement("completion:resource-template-arg")
async def test_complete_resource_template_variable(connect: Connect) -> None:
    """The handler receives the template URI and the variable name when a template variable is completed.

    The returned value is built from the argument's current value, so the expected result also
    shows that the value reached the handler.
    """

    async def complete_owner(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult:
        ref = params.ref
        assert isinstance(ref, ResourceTemplateReference)
        assert ref.uri == "github://repos/{owner}/{repo}"
        assert params.argument.name == "owner"
        suggestion = f"{params.argument.value}contextprotocol"
        return CompleteResult(completion=Completion(values=[suggestion]))

    server = Server("completer", on_completion=complete_owner)

    async with connect(server) as client:
        template_ref = ResourceTemplateReference(uri="github://repos/{owner}/{repo}")
        result = await client.complete(template_ref, argument={"name": "owner", "value": "model"})

    assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol"])))
@requirement("completion:error:invalid-ref")
async def test_completion_against_an_unknown_ref_is_rejected_with_invalid_params(connect: Connect) -> None:
    """A completion request whose ref names an unknown prompt is answered with -32602 Invalid params.

    The lowlevel server has no prompt/template registry to check refs against, so it does not
    validate them itself; rejecting an unknown ref is up to the handler. This test pins the
    spec-recommended way for a handler to do that.
    """

    async def reject_unknown_prompt(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult:
        ref = params.ref
        assert isinstance(ref, PromptReference)
        raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {ref.name!r}")

    server = Server("completer", on_completion=reject_unknown_prompt)

    async with connect(server) as client:
        ghost_prompt = PromptReference(name="ghost")
        with pytest.raises(MCPError) as exc_info:
            await client.complete(ghost_prompt, argument={"name": "x", "value": ""})

    assert exc_info.value.error.code == INVALID_PARAMS
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + ElicitCompleteNotification, + ElicitCompleteNotificationParams, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +REQUESTED_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], +} + + +@requirement("elicitation:form:action:accept") +@requirement("elicitation:form:basic") +@requirement("tools:call:elicitation-roundtrip") +async def test_elicit_form_accepted_content_returns_to_handler(connect: Connect) -> None: + """An accepted form elicitation returns the user's content to the requesting handler. + + The tool reports the action as text and the received content as structured content, proving + the client's answer made it back into the tool's own result. 
+ """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"username": "ada", "newsletter": True}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Choose a username.", + requested_schema={ + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept")], + structured_content={"username": "ada", "newsletter": True}, + ) + ) + + +@requirement("elicitation:form:action:decline") +async def test_elicit_form_decline_returns_no_content(connect: Connect) -> None: + """A declined form elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", 
input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:form:action:cancel") +async def test_elicit_form_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled form elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await 
client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:form:not-supported") +@requirement("elicitation:capability:server-respects-mode") +async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: + """Eliciting from a client that configured no elicitation callback fails with an error. + + The client's default callback answers with an Invalid request error, which the server-side + elicit call raises as an MCPError; the tool reports the code and message it caught. The spec + requires -32602 for an undeclared mode (see the divergence note on the requirement). The + request reaching the client also shows the server does not check the client's declared + elicitation capability before sending (see the divergence on `server-respects-mode`). + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ask", description="Ask the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask" + try: + await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # elicit_form cannot succeed without a client callback + + server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) + + +@requirement("elicitation:url:action:accept-no-content") +@requirement("elicitation:url:basic") +async def 
test_elicit_url_delivers_url_and_returns_accept_without_content(connect: Connect) -> None: + """A URL elicitation delivers the message, URL, and elicitation id to the client; accepting it + returns the action with no content. + + Accept means the user agreed to visit the URL, not that the out-of-band interaction finished, + so there is never form content to return. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert received == snapshot( + [ + ElicitRequestURLParams( + _meta={}, + message="Authorize access to your calendar.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="accept content=None")])) + + +@requirement("elicitation:url:decline") +async def test_elicit_url_decline_returns_no_content(connect: Connect) -> None: + """A declined URL elicitation returns the decline action to the handler with no content.""" + + async def list_tools( 
+ ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:url:cancel") +async def test_elicit_url_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled URL elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, 
on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:url:complete-notification") +async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client(connect: Connect) -> None: + """After a URL elicitation finishes, the server announces it with a notification carrying the same id. + + The lifecycle under test: the tool elicits a URL interaction with an elicitationId, the user + agrees to visit the URL, the out-of-band interaction finishes, and the server emits + elicitation/complete so the client can correlate the completion with the elicitation it + accepted earlier. Both messages arrive before the tool call returns, so a plain collected + list needs no synchronisation. 
+ """ + elicitation_id = "auth-001" + elicited_ids: list[str] = [] + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="link_account", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "link_account" + answer = await ctx.session.elicit_url( + "Authorize access to your files.", "https://example.com/oauth/authorize", elicitation_id + ) + assert answer.action == "accept" + await ctx.session.send_elicit_complete(elicitation_id) + return CallToolResult(content=[TextContent(text="linked")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestURLParams) + elicited_ids.append(params.elicitation_id) + return ElicitResult(action="accept") + + async with connect(server, message_handler=collect, elicitation_callback=answer_url) as client: + await client.call_tool("link_account", {}) + + # The completion notification refers to the same elicitation the client accepted. + assert elicited_ids == [elicitation_id] + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="auth-001"))] + ) + + +@requirement("elicitation:url:required-error") +async def test_url_elicitation_required_error_carries_pending_elicitations(connect: Connect) -> None: + """A request that cannot proceed until a URL interaction completes is rejected with error -32042. 
+ + This is the non-interactive alternative to elicit_url: instead of asking and waiting, the + handler rejects the whole request and lists the required URL elicitations in the error data. + The client is expected to present those URLs, wait for the matching elicitation/complete + notifications, and retry the original request. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + server = Server("authorizer", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("elicitation:form:schema:primitives") +@requirement("elicitation:form:schema:enum-variants") +async def test_elicit_form_schema_with_every_primitive_and_enum_type_reaches_the_callback_as_sent( + connect: Connect, +) -> None: + """A requested schema covering every spec-listed property kind is delivered to the callback unchanged. + + One schema with one property per kind: a formatted string, an integer with bounds, a number, + a boolean, a plain enum, a oneOf-const titled enum, and a multi-select array-of-enum. The + callback observing the same schema as the handler sent proves both the primitive coverage and + the enum-variant coverage in one snapshot. 
+ """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "email": {"type": "string", "format": "email", "title": "Email", "description": "Contact address."}, + "age": {"type": "integer", "minimum": 0, "maximum": 150}, + "score": {"type": "number"}, + "subscribe": {"type": "boolean", "default": False}, + "tier": {"type": "string", "enum": ["free", "pro", "team"]}, + "region": { + "oneOf": [ + {"const": "eu", "title": "Europe"}, + {"const": "na", "title": "North America"}, + ], + }, + "channels": {"type": "array", "items": {"type": "string", "enum": ["email", "sms", "push"]}}, + }, + "required": ["email"], + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="onboard", description="Onboard the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + answer = await ctx.session.elicit_form("Tell us about yourself.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"email": "ada@example.com"}) + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("onboard", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:schema:restricted-subset") +async def test_elicit_form_with_a_nested_schema_is_forwarded_unchanged(connect: Connect) -> None: + """A requested schema with nested-object and 
array-of-object properties passes through unchanged. + + The spec restricts form-mode requested schemas to flat objects with primitive-typed properties; + this test pins that the SDK does not enforce that restriction on either side (see the + divergence on the requirement). + """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + }, + "contacts": { + "type": "array", + "items": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, + }, + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="profile", description="Collect a profile.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "profile" + answer = await ctx.session.elicit_form("Profile details.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("profile", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:response-validation") +async def test_accepted_elicitation_content_that_violates_the_schema_reaches_the_handler_unchanged( + connect: Connect, +) -> None: + """Accepted form content that contradicts the requested schema is delivered to the handler unchanged. 
+ + The schema requires a string `name`; the callback answers with a wrong-type value and an extra + field. Nothing on either side validates the response against the schema (see the divergence on + the requirement), so the handler observes exactly what the callback sent. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form( + "Choose a name.", + {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content={"name": 42, "extra": "field"}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="accept")], structured_content={"name": 42, "extra": "field"}) + ) + + +@requirement("elicitation:url:complete-unknown-ignored") +async def test_elicitation_complete_for_an_unknown_id_is_received_without_error(connect: Connect) -> None: + """An elicitation/complete for an id the client never elicited is delivered and does not fail anything. + + No URL elicitation precedes the notification; the client neither tracks elicitation ids nor + rejects unknown ones, so the call completes normally and the message handler observes the + notification as-is. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="noop", description="Send a stray complete.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "noop" + await ctx.session.send_elicit_complete("never-elicited") + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(server, message_handler=collect) as client: + result = await client.call_tool("noop", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="never-elicited"))] + ) + + +@requirement("elicitation:form:mode-omitted-default") +async def test_a_mode_less_elicitation_request_is_treated_as_form_mode() -> None: + """An elicitation/create request with no mode field reaches the client callback as form-mode. + + The typed server API always serializes a mode (`elicit_form` writes 'form', `elicit_url` writes + 'url'), so this test plays the server's side of the wire by hand to send a request body without + one. Reserve this pattern for behaviour the typed server API cannot produce. 
+ """ + received: list[types.ElicitRequestParams] = [] + answered = anyio.Event() + server_received: list[JSONRPCMessage] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={}) + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def scripted_server() -> None: + initialize = await server_read.receive() + assert isinstance(initialize, SessionMessage) + request = initialize.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="legacy", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. 
+ await server_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, + ) + ) + ) + response = await server_read.receive() + assert isinstance(response, SessionMessage) + server_received.append(response.message) + answered.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(scripted_server) + async with ClientSession(client_read, client_write, elicitation_callback=answer_form) as session: + with anyio.fail_after(5): + await session.initialize() + await answered.wait() + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta=None, + message="Legacy ask.", + requested_schema={"type": "object", "properties": {}}, + ) + ] + ) + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].mode == "form" + assert len(server_received) == 1 + assert isinstance(server_received[0], JSONRPCResponse) + assert server_received[0].id == 2 diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py new file mode 100644 index 0000000000..8ff9dd4f1d --- /dev/null +++ b/tests/interaction/lowlevel/test_flows.py @@ -0,0 +1,193 @@ +"""Composed multi-feature flows against the low-level Server, driven through the public Client API. + +Each test reads as the scenario it proves: the steps run top to bottom in the order a real client +would perform them, composing two or more feature areas (a tool call followed by a resource read; +a chain of elicitations inside one tool call; the full URL-elicitation-required retry loop). The +individual features are pinned by their own tests; these prove they compose. 
+""" + +from collections.abc import Awaitable, Callable + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitCompleteNotification, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + ListToolsResult, + ReadResourceResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ListToolsHandler = Callable[ + [ServerRequestContext, types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult] +] + + +def _list_tools(*names: str) -> ListToolsHandler: + """A list_tools handler advertising the named tools, so call_tool's implicit list succeeds.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name=name, input_schema={"type": "object"}) for name in names]) + + return list_tools + + +@requirement("flow:tool-result:resource-link-follow") +async def test_a_resource_link_returned_by_a_tool_can_be_followed_with_read(connect: Connect) -> None: + """A tool returns a resource_link; reading that link's URI returns the referenced contents. + + Steps: (1) call the tool, (2) extract the link from its content, (3) read_resource on the + link's URI, (4) the read result carries the linked contents. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "generate" + return CallToolResult(content=[ResourceLink(uri="file:///report.txt", name="report")]) + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + assert str(params.uri) == "file:///report.txt" + return ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + + server = Server( + "linker", on_list_tools=_list_tools("generate"), on_call_tool=call_tool, on_read_resource=read_resource + ) + + async with connect(server) as client: + called = await client.call_tool("generate", {}) + link = called.content[0] + assert isinstance(link, ResourceLink) + read = await client.read_resource(link.uri) + + assert called == snapshot(CallToolResult(content=[ResourceLink(name="report", uri="file:///report.txt")])) + assert read == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + ) + + +@requirement("flow:elicitation:multi-step-form") +async def test_a_tool_handler_chains_form_elicitations_feeding_each_answer_forward(connect: Connect) -> None: + """Sequential form elicitations inside one tool call: each accepted answer feeds the next step. + + Steps: (1) call the tool, (2) the handler issues a step-one form elicitation that the client + accepts with content, (3) the handler issues a step-two elicitation whose message references + the step-one answer, (4) the client accepts step two, (5) the tool result summarises both + answers. The callback is invoked exactly twice with the expected messages and schemas. The + short-circuit on decline is the application's choice (proven separately by the per-action + elicitation tests); what this flow pins is that the chain itself works end to end. 
+ """ + received: list[ElicitRequestFormParams] = [] + answers: list[dict[str, str | int | float | bool | list[str] | None]] = [{"name": "ada"}, {"age": 37}] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + first = await ctx.session.elicit_form( + "Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}} + ) + assert first.action == "accept" and first.content is not None + second = await ctx.session.elicit_form( + f"Step 2: confirm age for {first.content['name']}.", + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ) + assert second.action == "accept" and second.content is not None + return CallToolResult(content=[TextContent(text=f"{first.content['name']} is {second.content['age']}")]) + + server = Server("onboarder", on_list_tools=_list_tools("onboard"), on_call_tool=call_tool) + + async def answer(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestFormParams) + received.append(params) + return ElicitResult(action="accept", content=answers[len(received) - 1]) + + async with connect(server, elicitation_callback=answer) as client: + result = await client.call_tool("onboard", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ada is 37")])) + assert [(p.message, p.requested_schema) for p in received] == snapshot( + [ + ("Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}}), + ("Step 2: confirm age for ada.", {"type": "object", "properties": {"age": {"type": "integer"}}}), + ] + ) + + +@requirement("flow:elicitation:url-required-then-retry") +async def test_a_tool_rejected_with_url_elicitation_required_succeeds_on_retry_after_completion( + connect: Connect, +) -> None: + """The full URL-elicitation-required retry loop: -32042, completion announced, retry succeeds. 
+ + Steps: (1) the first call is rejected with -32042 carrying the required URL elicitation in + its error data, (2) the client extracts the elicitation id from the error, (3) the server + announces completion via the elicitation/complete notification (driven via the captured + session, the same way a real out-of-band callback would reach a held session reference), + (4) the client observes the matching completion notification and retries, (5) the retry + succeeds. The handler distinguishes the two calls by a closure flag the test flips between + them; the test waits on the completion notification with an event so the retry only happens + after the announcement has arrived. + """ + elicitation_id = "auth-001" + authorised: list[bool] = [False] + captured: list[ServerSession] = [] + completed = anyio.Event() + notifications: list[ElicitCompleteNotification] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + captured.append(ctx.session) + if not authorised[0]: + # The log line gives the message handler a non-completion notification, so the test's + # filtering branch is exercised in both directions and the wait remains specific. 
+ await ctx.session.send_log_message(level="warning", data="authorisation required", logger="gate") + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorize file access.", + url="https://example.com/oauth/authorize", + elicitation_id=elicitation_id, + ) + ] + ) + return CallToolResult(content=[TextContent(text="contents")]) + + server = Server("gatekeeper", on_list_tools=_list_tools("read_files"), on_call_tool=call_tool) + + async def collect(message: IncomingMessage) -> None: + if isinstance(message, ElicitCompleteNotification): + notifications.append(message) + completed.set() + + async with connect(server, message_handler=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + required = UrlElicitationRequiredError.from_error(exc_info.value.error) + assert [e.elicitation_id for e in required.elicitations] == [elicitation_id] + + # The out-of-band interaction completes; the server announces it on the same session. + await captured[0].send_elicit_complete(elicitation_id) + with anyio.fail_after(5): + await completed.wait() + assert notifications[0].params.elicitation_id == elicitation_id + + authorised[0] = True + result = await client.call_tool("read_files", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="contents")])) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py new file mode 100644 index 0000000000..027c80505d --- /dev/null +++ b/tests/interaction/lowlevel/test_initialize.py @@ -0,0 +1,376 @@ +"""Initialization handshake against the low-level Server, driven through the public Client API. 
+ +The later tests drive a bare ClientSession over an InMemoryTransport instead: Client always +performs the full handshake with the latest protocol version, so skipping initialization or +requesting a different version can only be expressed one level down. The final two tests go one +step further and play the server's side of the wire by hand, because no real Server can be made +to answer initialize with an unsupported version or with a fixed older one on demand. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + INVALID_PARAMS, + CallToolResult, + ClientCapabilities, + CompletionsCapability, + EmptyResult, + ErrorData, + Icon, + Implementation, + InitializeRequest, + InitializeRequestParams, + InitializeResult, + JSONRPCRequest, + JSONRPCResponse, + ListToolsRequest, + ListToolsResult, + LoggingCapability, + PromptsCapability, + ResourcesCapability, + ServerCapabilities, + TextContent, + ToolsCapability, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:initialize:basic") +@requirement("lifecycle:initialize:server-info") +async def test_initialize_returns_server_info(connect: Connect) -> None: + """Every identity field the server declares is returned to the client in server_info.""" + server = Server( + "greeter", + version="1.2.3", + title="Greeter", + description="Greets people.", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + + async with connect(server) as client: + server_info = client.initialize_result.server_info + + 
assert server_info == snapshot( + Implementation( + name="greeter", + title="Greeter", + description="Greets people.", + version="1.2.3", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + ) + + +@requirement("lifecycle:initialize:instructions") +async def test_initialize_returns_instructions(connect: Connect) -> None: + """Instructions are returned when the server declares them and omitted when it does not.""" + async with connect(Server("guided", instructions="Call the add tool.")) as client: + assert client.initialize_result.instructions == snapshot("Call the add tool.") + + async with connect(Server("unguided")) as client: + assert client.initialize_result.instructions is None + + +@requirement("lifecycle:initialize:capabilities:from-handlers") +@requirement("tools:capability:declared") +@requirement("resources:capability:declared") +@requirement("prompts:capability:declared") +@requirement("completion:capability:declared") +async def test_initialize_capabilities_reflect_registered_handlers(connect: Connect) -> None: + """Each feature area with a registered handler is advertised as a capability. + + The in-memory transport connects with default initialization options, so the + list_changed flags are always False regardless of the server's notification behaviour. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> types.EmptyResult: + """Registered only so the subscribe sub-capability is advertised; never called.""" + raise NotImplementedError + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered only so the prompts capability is advertised; never called.""" + raise NotImplementedError + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> types.EmptyResult: + """Registered only so the logging capability is advertised; never called.""" + raise NotImplementedError + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + server = Server( + "full", + on_list_tools=list_tools, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + on_list_prompts=list_prompts, + on_set_logging_level=set_logging_level, + on_completion=completion, + ) + + async with connect(server) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot( + ServerCapabilities( + experimental={}, + logging=LoggingCapability(), + prompts=PromptsCapability(list_changed=False), + resources=ResourcesCapability(subscribe=True, list_changed=False), + 
tools=ToolsCapability(list_changed=False), + completions=CompletionsCapability(), + ) + ) + + +@requirement("lifecycle:initialize:capabilities:minimal") +async def test_initialize_minimal_server_advertises_no_capabilities(connect: Connect) -> None: + """A server with no feature handlers advertises no feature capabilities.""" + async with connect(Server("bare")) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot(ServerCapabilities(experimental={})) + + +@requirement("lifecycle:initialize:client-info") +async def test_initialize_server_sees_client_info(connect: Connect) -> None: + """The client identity supplied to Client is visible to server handlers after initialization.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="whoami", description="Report the caller.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + assert ctx.session.client_params is not None + client_info = ctx.session.client_params.client_info + return CallToolResult(content=[TextContent(text=f"{client_info.name} {client_info.version}")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + async with connect(server, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="acme-agent 9.9.9")])) + + +@requirement("lifecycle:initialize:client-capabilities") +async def test_initialize_server_sees_client_capabilities(connect: Connect) -> None: + """The client capabilities visible to the server reflect which callbacks the client configured.""" + + async def list_tools( + ctx: ServerRequestContext, params: 
types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="abilities", description="Report capabilities.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "abilities" + assert ctx.session.client_params is not None + capabilities = ctx.session.client_params.capabilities + declared = [ + name + for name, value in ( + ("sampling", capabilities.sampling), + ("elicitation", capabilities.elicitation), + ) + if value is not None + ] + if capabilities.roots is not None: + declared.append(f"roots(list_changed={capabilities.roots.list_changed})") + return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) + + async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: + """Registered only so the client declares the roots capability; never called.""" + raise NotImplementedError + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="none")])) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="roots(list_changed=True)")])) + + +@requirement("lifecycle:requests-before-initialized") +async def test_request_before_initialization_is_rejected() -> None: + """A feature request sent before the handshake completes is rejected; ping is exempt. + + Client always initializes on entry, so this drives a bare ClientSession that never sends + initialize. The server's stated reason for the rejection never reaches the client: the error + is reported as a generic invalid-params failure. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + """Registered so the request is routed to a real handler; never reached.""" + raise NotImplementedError + + server = Server("strict", on_list_tools=list_tools) + + async with InMemoryTransport(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.send_request(ListToolsRequest(), ListToolsResult) + + # Ping is explicitly permitted before initialization completes. + pong = await session.send_ping() + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + ) + assert pong == snapshot(EmptyResult()) + + +@requirement("lifecycle:version:match") +@requirement("lifecycle:version:server-fallback-latest") +async def test_initialize_negotiates_protocol_version() -> None: + """The server echoes a supported requested version and answers an unsupported one with its latest. + + Client always requests the latest version, so each half hand-builds an InitializeRequest on a + bare ClientSession to control the requested version. 
+ """ + server = Server("negotiator") + + def initialize_request(protocol_version: str) -> InitializeRequest: + return InitializeRequest( + params=InitializeRequestParams( + protocol_version=protocol_version, + capabilities=ClientCapabilities(), + client_info=Implementation(name="time-traveller", version="0.0.1"), + ) + ) + + async with InMemoryTransport(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + with anyio.fail_after(5): + result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) + assert result.protocol_version == snapshot("2025-03-26") + + async with InMemoryTransport(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + with anyio.fail_after(5): + result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) + assert result.protocol_version == snapshot("2025-11-25") + + +@requirement("lifecycle:version:reject-unsupported") +async def test_unsupported_server_protocol_version_fails_initialization() -> None: + """An initialize response carrying a protocol version the client does not support fails initialization. + + A real Server only ever answers with a version it supports, so this test plays the + server's side of the wire by hand: it reads the initialize request off the raw stream and + answers it with a hand-built result. Reserve this pattern for behaviour no real server can + be made to produce. 
+ """ + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def scripted_server() -> None: + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="1991-08-06", + capabilities=ServerCapabilities(), + server_info=Implementation(name="relic", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(scripted_server) + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + with pytest.raises(RuntimeError) as exc_info: + await session.initialize() + + assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") + + +@requirement("lifecycle:version:downgrade") +async def test_an_older_supported_protocol_version_from_the_server_is_accepted() -> None: + """An initialize response carrying an older supported protocol version completes the handshake at that version. + + A real Server answers with the version the client requested (or its own latest), so this test + plays the server's side of the wire by hand to return a fixed older version regardless of what + was requested. Reserve this pattern for behaviour no real server can be made to produce. 
+ """ + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def scripted_server() -> None: + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-06-18", + capabilities=ServerCapabilities(), + server_info=Implementation(name="conservative", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(scripted_server) + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + initialize_result = await session.initialize() + + assert initialize_result.protocol_version == snapshot("2025-06-18") diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py new file mode 100644 index 0000000000..eb20db207b --- /dev/null +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -0,0 +1,103 @@ +"""List-changed notifications from the low-level Server, driven through the public Client API. + +The notifications are emitted from inside a tool call, so the ordering guarantee described in +test_logging.py applies: they reach the client's message handler before the tool call returns, +and the tests assert on a plain collected list with no synchronisation. The collector records +every message the handler receives, so the assertions also prove nothing else was delivered. 
+""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolResult, + PromptListChangedNotification, + ResourceListChangedNotification, + TextContent, + ToolListChangedNotification, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list-changed") +async def test_tool_list_changed_notification(connect: Connect) -> None: + """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="install", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "install" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="installed")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("install", {}) + + assert received == snapshot([ToolListChangedNotification()]) + + +@requirement("resources:list-changed") +async def test_resource_list_changed_notification(connect: Connect) -> None: + """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: 
types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="mount", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "mount" + await ctx.session.send_resource_list_changed() + return CallToolResult(content=[TextContent(text="mounted")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("mount", {}) + + assert received == snapshot([ResourceListChangedNotification()]) + + +@requirement("prompts:list-changed") +async def test_prompt_list_changed_notification(connect: Connect) -> None: + """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="learn", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "learn" + await ctx.session.send_prompt_list_changed() + return CallToolResult(content=[TextContent(text="learned")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("learn", {}) + + assert received == snapshot([PromptListChangedNotification()]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py new file mode 100644 index 0000000000..792334ecd2 --- /dev/null +++ b/tests/interaction/lowlevel/test_logging.py @@ -0,0 +1,113 @@ +"""Logging 
interactions against the low-level Server, driven through the public Client API. + +Notification ordering: the in-memory transport delivers every server-to-client message on one +ordered stream, and the client's receive loop dispatches each incoming message to completion +before reading the next one. Together these guarantee that every notification the server sends +before its response reaches the client callback before the originating request returns, so tests +collect notifications into a plain list and assert after the request completes -- no events, no +waiting. This does not generalise to transports that split messages across streams (the +streamable HTTP standalone GET stream); tests over those transports must synchronise explicitly. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ALL_LEVELS: tuple[types.LoggingLevel, ...] 
= ( + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency", +) + + +@requirement("logging:set-level") +async def test_set_logging_level_reaches_handler(connect: Connect) -> None: + """The level requested by the client is delivered to the server's handler verbatim.""" + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + assert params.level == "warning" + return EmptyResult() + + server = Server("logger", on_set_logging_level=set_logging_level) + + async with connect(server) as client: + result = await client.set_logging_level("warning") + + assert result == snapshot(EmptyResult()) + + +@requirement("logging:message:fields") +@requirement("tools:call:logging-mid-execution") +async def test_log_messages_reach_logging_callback_in_order(connect: Connect) -> None: + """Log messages sent during a tool call arrive at the logging callback, in order, before the call returns. + + The two messages pin the full notification shape: severity, optional logger name, and both + string and structured data payloads. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="chatty", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "chatty" + await ctx.session.send_log_message(level="info", data="starting up", logger="app.lifecycle") + await ctx.session.send_log_message(level="error", data={"code": 502, "retryable": True}) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, logging_callback=collect) as client: + result = await client.call_tool("chatty", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")])) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="info", logger="app.lifecycle", data="starting up"), + LoggingMessageNotificationParams(level="error", data={"code": 502, "retryable": True}), + ] + ) + + +@requirement("logging:message:all-levels") +async def test_log_messages_at_every_severity_level(connect: Connect) -> None: + """Each of the eight RFC 5424 severity levels is deliverable as a log message notification.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="siren", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + 
assert params.name == "siren" + for level in ALL_LEVELS: + await ctx.session.send_log_message(level=level, data=f"a {level} message") + return CallToolResult(content=[TextContent(text="logged")]) + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, logging_callback=collect) as client: + await client.call_tool("siren", {}) + + assert [params.level for params in received] == list(ALL_LEVELS) diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py new file mode 100644 index 0000000000..a9e4f994d8 --- /dev/null +++ b/tests/interaction/lowlevel/test_meta.py @@ -0,0 +1,63 @@ +"""Request and result _meta round trips against the low-level Server, through the public Client API. + +Meta is opaque pass-through data, so these tests assert identity against the value that was sent +rather than snapshotting a literal: the expected value and the sent value are the same variable, +which also proves the SDK injected nothing alongside it. 
+""" + +import pytest + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, RequestParamsMeta, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("meta:request-to-handler") +async def test_request_meta_reaches_handler(connect: Connect) -> None: + """The _meta object the client attaches to a request arrives at the tool handler unchanged.""" + request_meta: RequestParamsMeta = {"example.com/trace": "abc-123"} + observed_metas: list[dict[str, object]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="traced", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "traced" + assert ctx.meta is not None + observed_metas.append(dict(ctx.meta)) + return CallToolResult(content=[TextContent(text="traced")]) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("traced", {}, meta=request_meta) + + assert observed_metas == [dict(request_meta)] + + +@requirement("meta:result-to-client") +async def test_result_meta_reaches_client(connect: Connect) -> None: + """The _meta object a handler attaches to its result is delivered to the client unchanged.""" + result_meta = {"example.com/cost": 3} + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="metered", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "metered" + 
return CallToolResult(content=[TextContent(text="done")], _meta=result_meta) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("metered", {}) + + assert result == CallToolResult(content=[TextContent(text="done")], _meta=result_meta) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py new file mode 100644 index 0000000000..0c2a0b1588 --- /dev/null +++ b/tests/interaction/lowlevel/test_pagination.py @@ -0,0 +1,239 @@ +"""Cursor pagination of the list operations against the low-level Server. + +The cursor is an opaque string chosen by the server: the suite only asserts that whatever the +handler returns as next_cursor comes back verbatim on the client's next call, not any particular +pagination scheme. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + Prompt, + Resource, + ResourceTemplate, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list:pagination") +async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None: + """The next_cursor a list handler returns reaches the client, and the cursor the client sends + back on the following call reaches the handler verbatim. 
@requirement("pagination:exhaustion")
@requirement("tools:list:pagination")
async def test_paginating_until_next_cursor_is_absent_yields_every_page(connect: Connect) -> None:
    """Walk a three-page tool listing by feeding each next_cursor back until none is returned.

    The handler serves one tool per page. The loop must finish in exactly as many requests as
    there are pages and must collect the tool names in page order.
    """
    # Request cursor -> (tool served on that page, cursor handed out for the page after it).
    pages: dict[str | None, tuple[str, str | None]] = {
        None: ("alpha", "page-2"),
        "page-2": ("beta", "page-3"),
        "page-3": ("gamma", None),
    }

    async def serve_page(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult:
        assert params is not None
        name, following = pages[params.cursor]
        only_tool = Tool(name=name, input_schema={"type": "object"})
        return ListToolsResult(tools=[only_tool], next_cursor=following)

    names_seen: list[str] = []
    next_request_cursor: str | None = None
    requests_made = 0
    exhausted = False
    async with connect(Server("paginated", on_list_tools=serve_page)) as client:
        while not exhausted:
            page = await client.list_tools(cursor=next_request_cursor)
            requests_made += 1
            # Fail instead of looping forever if a next_cursor keeps coming back.
            assert requests_made <= len(pages), "the server kept returning next_cursor past the last page"
            names_seen += [tool.name for tool in page.tools]
            next_request_cursor = page.next_cursor
            exhausted = next_request_cursor is None

    assert names_seen == snapshot(["alpha", "beta", "gamma"])
    assert requests_made == len(pages)
@requirement("resources:list:pagination")
async def test_resources_list_supports_cursor_pagination(connect: Connect) -> None:
    """Fetch two pages of resources/list and check the cursor reaches the handler on the second request."""
    cursors_received: list[str | None] = []

    async def serve_resources(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> ListResourcesResult:
        assert params is not None
        cursors_received.append(params.cursor)
        if params.cursor is not None:
            # A request that carries a cursor gets the last page, which has no next_cursor.
            return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")])
        return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor="page-2")

    async with connect(Server("paginated", on_list_resources=serve_resources)) as client:
        opening_page = await client.list_resources()
        closing_page = await client.list_resources(cursor="page-2")

    assert cursors_received == snapshot([None, "page-2"])

    opening_names = [entry.name for entry in opening_page.resources]
    assert opening_names == ["first"]
    assert opening_page.next_cursor == "page-2"

    closing_names = [entry.name for entry in closing_page.resources]
    assert closing_names == ["second"]
    assert closing_page.next_cursor is None
@requirement("lifecycle:ping")
@requirement("ping:server-to-client")
async def test_server_ping_returns_empty_result(connect: Connect) -> None:
    """Ping the client from inside a tool handler and confirm the client answers it.

    The handler awaits the ping before building its result and reports the class name of the
    response as the tool's text, so the final snapshot shows what the ping returned.
    """

    async def advertise_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        ping_back = types.Tool(name="ping_back", description="Ping the client.", input_schema={"type": "object"})
        return types.ListToolsResult(tools=[ping_back])

    async def ping_then_report(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
        assert params.name == "ping_back"
        pong = await ctx.session.send_ping()
        response_kind = type(pong).__name__
        return CallToolResult(content=[TextContent(text=response_kind)])

    async with connect(Server("pinger", on_list_tools=advertise_tools, on_call_tool=ping_then_report)) as client:
        outcome = await client.call_tool("ping_back", {})

    assert outcome == snapshot(CallToolResult(content=[TextContent(text="EmptyResult")]))
@requirement("protocol:progress:callback")
@requirement("tools:call:progress")
async def test_progress_during_tool_call_reaches_callback_in_order(connect: Connect) -> None:
    """Emit three progress notifications from a tool handler and check the callback records them in order."""
    updates: list[tuple[float, float | None, str | None]] = []

    async def record_update(progress: float, total: float | None, message: str | None) -> None:
        updates.append((progress, total, message))

    async def advertise_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        download = types.Tool(name="download", input_schema={"type": "object"})
        return types.ListToolsResult(tools=[download])

    async def download_in_chunks(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
        assert params.name == "download"
        assert ctx.meta is not None
        token = ctx.meta.get("progress_token")
        assert token is not None
        # One notification per step, all against the same total, sent before the result is returned.
        for step, note in ((1.0, "first chunk"), (2.0, "second chunk"), (3.0, "done")):
            await ctx.session.send_progress_notification(token, step, total=3.0, message=note)
        return CallToolResult(content=[TextContent(text="downloaded")])

    downloader = Server("downloader", on_list_tools=advertise_tools, on_call_tool=download_in_chunks)

    async with connect(downloader) as client:
        outcome = await client.call_tool("download", {}, progress_callback=record_update)

    assert outcome == snapshot(CallToolResult(content=[TextContent(text="downloaded")]))
    assert updates == snapshot([(1.0, 3.0, "first chunk"), (2.0, 3.0, "second chunk"), (3.0, 3.0, "done")])
@requirement("protocol:progress:client-to-server")
async def test_client_progress_notification_reaches_server_handler(connect: Connect) -> None:
    """Send one progress notification from the client and check the server's handler receives its params.

    The handler sets an event when it runs, and the test waits on that event (bounded by a
    timeout) before asserting on what the handler recorded.
    """
    handler_ran = anyio.Event()
    params_seen: list[ProgressNotificationParams] = []

    async def record_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None:
        params_seen.append(params)
        handler_ran.set()

    async with connect(Server("observer", on_progress=record_progress)) as client:
        await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway")
        with anyio.fail_after(5):
            await handler_ran.wait()

    assert params_seen == snapshot(
        [ProgressNotificationParams(progress_token="upload-1", progress=0.5, total=1.0, message="halfway")]
    )
+ + Without the barrier the first call could run to completion before the second starts, so only one + token would be live at a time and the demultiplexing would never be exercised. The handlers each + block until both have started and then hand control back and forth so the four progress + notifications are emitted in strict a, b, a, b order on the wire. The two handlers send different + progress values so a stream swap (token A delivered to callback B and vice versa) would fail: each + callback receiving exactly its own values proves notifications are routed by token, not by arrival + order or by chance. + """ + progress_values = {"a": (1.0, 2.0), "b": (10.0, 20.0)} + tokens: dict[str, ProgressToken] = {} + entered = {"a": anyio.Event(), "b": anyio.Event()} + # turns[n] is set to release the nth emission; each emission releases the next. + turns = [anyio.Event() for _ in range(4)] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert params.arguments is not None + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + label = params.arguments["label"] + tokens[label] = token + entered[label].set() + # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. 
@requirement("protocol:progress:stops-after-completion")
@requirement("protocol:progress:late-dropped-by-client")
async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback(connect: Connect) -> None:
    """A progress notification sent after the response is emitted, and the client drops it from the callback.

    This single body proves both halves: the server's `send_progress_notification` happily sends for
    a token whose request has already completed (the spec's MUST that progress stops is not enforced;
    see the divergence on `stops-after-completion`), and the client, having removed the callback when
    the call returned, does not deliver the late notification to it. The message handler observes the
    late notification arriving so the test knows when to assert without polling.
    """
    # (session, token) pairs stashed by the tool handler so the test body can send progress
    # for the same token after the tool call has returned.
    captured: list[tuple[ServerSession, ProgressToken]] = []

    async def list_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})])

    async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
        assert params.name == "report"
        assert ctx.meta is not None
        token = ctx.meta.get("progress_token")
        assert token is not None
        captured.append((ctx.session, token))
        # The in-flight notification: sent before the result, so the callback should see 0.5.
        await ctx.session.send_progress_notification(token, 0.5)
        return CallToolResult(content=[TextContent(text="done")])

    server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool)

    # Progress values handed to the per-call callback.
    received: list[float] = []
    late_progress_arrived = anyio.Event()

    async def collect(progress: float, total: float | None, message: str | None) -> None:
        received.append(progress)

    async def message_handler(message: IncomingMessage) -> None:
        # Only the late notification carries progress 1.0 (the in-flight one carries 0.5), so this
        # event marks the moment the late notification has reached the client.
        if isinstance(message, ProgressNotification) and message.params.progress == 1.0:
            late_progress_arrived.set()

    async with connect(server, message_handler=message_handler) as client:
        with anyio.fail_after(5):
            await client.call_tool("report", {}, progress_callback=collect)
            assert received == [0.5]

            # The call has returned; reuse the handler's session and token to send one more update.
            server_session, token = captured[0]
            await server_session.send_progress_notification(token, 1.0)
            await late_progress_arrived.wait()

    # Unchanged from the assertion above: the late 1.0 never reached the callback.
    assert received == [0.5]
@requirement("protocol:progress:monotonic")
async def test_non_increasing_progress_values_are_forwarded_unchanged(connect: Connect) -> None:
    """Send progress values that go down and then up, and check the callback sees them exactly as sent.

    The spec says progress MUST increase with each notification; the SDK does not enforce that on
    either side (see the divergence note on the requirement), so 0.5, 0.3, 0.9 arrive untouched.
    """
    values_seen: list[float] = []

    async def record_value(progress: float, total: float | None, message: str | None) -> None:
        values_seen.append(progress)

    async def advertise_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        zigzag = types.Tool(name="zigzag", input_schema={"type": "object"})
        return types.ListToolsResult(tools=[zigzag])

    async def emit_zigzag(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
        assert params.name == "zigzag"
        assert ctx.meta is not None
        token = ctx.meta.get("progress_token")
        assert token is not None
        # Deliberately not monotonic: the second value is lower than the first.
        for value in (0.5, 0.3, 0.9):
            await ctx.session.send_progress_notification(token, value)
        return CallToolResult(content=[TextContent(text="done")])

    async with connect(Server("zigzagger", on_list_tools=advertise_tools, on_call_tool=emit_zigzag)) as client:
        await client.call_tool("zigzag", {}, progress_callback=record_value)

    assert values_seen == snapshot([0.5, 0.3, 0.9])
@requirement("prompts:list:basic")
async def test_list_prompts_returns_registered_prompts(connect: Connect) -> None:
    """List prompts and check both a fully described prompt and a name-only prompt arrive unchanged.

    The described prompt carries two argument declarations (one required) and an icon, so the
    snapshot covers those fields as well as the bare case.
    """

    async def serve_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult:
        review_arguments = [
            PromptArgument(name="code", description="The code to review.", required=True),
            PromptArgument(name="style_guide", description="Optional style guide to apply."),
        ]
        review_icon = Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])
        described = Prompt(
            name="code_review",
            description="Review a piece of code.",
            arguments=review_arguments,
            icons=[review_icon],
        )
        name_only = Prompt(name="daily_standup")
        return ListPromptsResult(prompts=[described, name_only])

    async with connect(Server("prompter", on_list_prompts=serve_prompts)) as client:
        listing = await client.list_prompts()

    assert listing == snapshot(
        ListPromptsResult(
            prompts=[
                Prompt(
                    name="code_review",
                    description="Review a piece of code.",
                    arguments=[
                        PromptArgument(name="code", description="The code to review.", required=True),
                        PromptArgument(name="style_guide", description="Optional style guide to apply."),
                    ],
                    icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])],
                ),
                Prompt(name="daily_standup"),
            ]
        )
    )
@requirement("prompts:get:no-args")
async def test_get_prompt_without_arguments_returns_the_messages(connect: Connect) -> None:
    """Fetch a prompt without passing arguments: the handler sees arguments=None and its messages come back."""

    async def serve_static_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult:
        assert params.name == "static"
        assert params.arguments is None
        instruction = PromptMessage(role="user", content=TextContent(text="Say hello."))
        return GetPromptResult(messages=[instruction])

    async with connect(Server("prompter", on_get_prompt=serve_static_prompt)) as client:
        fetched = await client.get_prompt("static")

    assert fetched == snapshot(
        GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))])
    )
@requirement("prompts:get:unknown-name")
async def test_get_prompt_unknown_name_is_protocol_error(connect: Connect) -> None:
    """Raise MCPError from the prompt handler and check the client receives the same code and message.

    The handler rejects whatever name it is given; the snapshot pins that the error data the
    client sees matches what the handler raised.
    """

    async def reject_every_name(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult:
        complaint = f"Unknown prompt: {params.name}"
        raise MCPError(code=INVALID_PARAMS, message=complaint)

    async with connect(Server("prompter", on_get_prompt=reject_every_name)) as client:
        with pytest.raises(MCPError) as raised:
            await client.get_prompt("nope")

    assert raised.value.error == snapshot(ErrorData(code=INVALID_PARAMS, message="Unknown prompt: nope"))
@requirement("resources:read:text")
async def test_read_resource_text(connect: Connect) -> None:
    """Read a text resource and check the URI, MIME type, and text chosen by the handler come back."""

    async def serve_greeting(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult:
        # Echo the requested URI so the snapshot also shows which URI the handler was given.
        greeting = TextResourceContents(uri=params.uri, mime_type="text/plain", text="Hello, world!")
        return ReadResourceResult(contents=[greeting])

    async with connect(Server("library", on_read_resource=serve_greeting)) as client:
        fetched = await client.read_resource("file:///greeting.txt")

    assert fetched == snapshot(
        ReadResourceResult(
            contents=[TextResourceContents(uri="file:///greeting.txt", mime_type="text/plain", text="Hello, world!")]
        )
    )
+ """ + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + raise MCPError(code=-32002, message=f"Resource not found: {params.uri}") + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("file:///missing.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=-32002, message="Resource not found: file:///missing.txt")) + + +@requirement("resources:templates:list") +async def test_list_resource_templates_returns_registered_templates(connect: Connect) -> None: + """Listed resource templates reach the client with their URI templates and descriptive fields intact.""" + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + result = await client.list_resource_templates() + + assert result == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:subscribe") 
+async def test_subscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_subscribe_resource=subscribe_resource) + + async with connect(server) as client: + result = await client.subscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:subscribe:capability-required") +async def test_subscribe_without_a_subscribe_handler_is_method_not_found(connect: Connect) -> None: + """Subscribing to a server that registered no subscribe handler is rejected with METHOD_NOT_FOUND. + + The rejection comes from no handler being registered, not from any capability check; see the + divergence on lifecycle:capability:server-not-advertised. 
+ """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.subscribe_resource("file:///watched.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + + +@requirement("resources:unsubscribe") +async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" + + async def unsubscribe_resource(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_unsubscribe_resource=unsubscribe_resource) + + async with connect(server) as client: + result = await client.unsubscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:updated-notification") +async def test_resource_updated_notification_reaches_client(connect: Connect) -> None: + """A resources/updated notification sent during a tool call reaches the client with the resource URI. + + The collector records every message the handler receives, so the assertion also proves nothing + else was delivered. 
+ """ + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="touched")]) + + server = Server("library", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("touch", {}) + + assert received == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py new file mode 100644 index 0000000000..577b99819c --- /dev/null +++ b/tests/interaction/lowlevel/test_roots.py @@ -0,0 +1,162 @@ +"""Roots interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import FileUrl + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("roots:list:basic") +async def test_list_roots_round_trip(connect: Connect) -> None: + """A roots/list request from a tool handler is answered by the client's roots callback. 
+ + The tool reports the URIs and names it received, proving the client's roots reached the server. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + result = await ctx.session.list_roots() + lines = [f"{root.uri} name={root.name}" for root in result.roots] + return CallToolResult(content=[TextContent(text="\n".join(lines))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult( + roots=[ + Root(uri=FileUrl("file:///home/alice/project"), name="project"), + Root(uri=FileUrl("file:///home/alice/scratch")), + ] + ) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="file:///home/alice/project name=project\nfile:///home/alice/scratch name=None")] + ) + ) + + +@requirement("roots:list:empty") +async def test_list_roots_empty(connect: Connect) -> None: + """A client with no roots to offer answers roots/list with an empty list, not an error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="count_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "count_roots" + result = await ctx.session.list_roots() + return CallToolResult(content=[TextContent(text=str(len(result.roots)))]) + + server = Server("rooted", 
on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult(roots=[]) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("count_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="0")])) + + +@requirement("roots:list:not-supported") +async def test_list_roots_without_callback_is_error(connect: Connect) -> None: + """A roots/list request to a client with no roots callback fails with an error the handler can observe. + + The client's default callback answers with INVALID_REQUEST rather than leaving the server + hanging; the spec names -32601 for this case (see the divergence note on the requirement). + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # list_roots cannot succeed without a client callback + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) + + +@requirement("roots:list:client-error") +async def test_list_roots_callback_error_surfaces_to_the_handler(connect: Connect) -> None: + """A roots callback that answers with an error fails the roots/list request with that exact error. 
+ + The callback's code and message reach the requesting handler verbatim as an MCPError. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ErrorData: + return ErrorData(code=INTERNAL_ERROR, message="roots provider crashed") + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32603: roots provider crashed")])) + + +@requirement("roots:list-changed") +async def test_roots_list_changed_reaches_server_handler(connect: Connect) -> None: + """A roots/list_changed notification from the client is delivered to the server's handler. + + Unlike a request, a notification has no response to await: the handler sets an event and the + test waits on it, which is the only synchronisation point proving delivery. 
+ """ + delivered = anyio.Event() + received: list[types.NotificationParams | None] = [] + + async def roots_list_changed(ctx: ServerRequestContext, params: types.NotificationParams | None) -> None: + received.append(params) + delivered.set() + + server = Server("rooted", on_roots_list_changed=roots_list_changed) + + async with connect(server) as client: + await client.send_roots_list_changed() + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot([None]) diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py new file mode 100644 index 0000000000..53a246b2e8 --- /dev/null +++ b/tests/interaction/lowlevel/test_sampling.py @@ -0,0 +1,686 @@ +"""Sampling interactions against the low-level Server, driven through the public Client API. + +Each test nests a sampling/createMessage request inside a tool call: the tool handler calls +ctx.session.create_message(), the client's sampling callback answers it, and the handler +round-trips what it received back to the test through its tool result. 
+""" + +import pydantic +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + AudioContent, + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ErrorData, + ImageContent, + ModelHint, + ModelPreferences, + SamplingCapability, + SamplingMessage, + TextContent, + ToolResultContent, + ToolUseContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("sampling:create:basic") +@requirement("tools:call:sampling-roundtrip") +async def test_create_message_round_trip(connect: Connect) -> None: + """A handler's sampling request is answered by the client callback, and the callback's result + (role, content, model, stop reason) is returned to the handler. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=f"{result.model}/{result.stop_reason}: {result.content.text}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult( + 
role="assistant", + content=TextContent(text="Hello to you too."), + model="mock-llm-1", + stop_reason="endTurn", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-llm-1/endTurn: Hello to you too.")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create:include-context") +@requirement("sampling:create:model-preferences") +@requirement("sampling:create:system-prompt") +@requirement("sampling:context:server-gated-by-capability") +async def test_create_message_params_reach_callback(connect: Connect) -> None: + """Every sampling parameter the handler supplies arrives at the client callback unchanged. + + The client has not declared the sampling.context capability (Client cannot declare it), yet + include_context="thisServer" reaches the callback regardless: the spec's SHOULD NOT is not + enforced. See the divergence note on `sampling:context:server-gated-by-capability`. 
+ """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + max_tokens=50, + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + stop_sequences=["\n\n", "END"], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult(role="assistant", content=TextContent(text="ok"), model="mock-llm-1") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + max_tokens=50, + 
stop_sequences=["\n\n", "END"], + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_request_with_image_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying image content arrives at the client callback intact. + + This is the server-to-client direction: the server includes an image in the conversation it + asks the client to sample from. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="describe_image", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "describe_image" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + image = params.messages[0].content + assert isinstance(image, ImageContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"described {image.mime_type} ({image.data})"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("describe_image", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="described image/png (aW1n)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", 
content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_result_with_image_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is an image is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is an image rather than text. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="draw", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "draw" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Draw a cat."))], + max_tokens=100, + ) + image = result.content + assert isinstance(image, ImageContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {image.mime_type} {image.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=ImageContent(data="Y2F0", mime_type="image/png"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("draw", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) + + +@requirement("sampling:error:user-rejected") +async def test_create_message_callback_error(connect: Connect) -> None: + """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. 
+ + The error here is the spec's own example for a user rejecting a sampling request (code -1); + the callback's code and message reach the handler verbatim, whatever they are. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> ErrorData: + return ErrorData(code=-1, message="User rejected sampling request") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-1: User rejected sampling request")])) + + +@requirement("sampling:create-message:not-supported") +async def test_create_message_without_callback_is_error(connect: Connect) -> None: + """A sampling request to a client with no sampling callback fails with the SDK's default error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert 
params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # create_message cannot succeed without a client callback + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) + + +@requirement("sampling:tools:server-gated-by-capability") +async def test_create_message_with_tools_is_rejected_for_unsupporting_client(connect: Connect) -> None: + """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. + + The client supports plain sampling but cannot declare the tools sub-capability (Client does not + expose it), so the server-side validator rejects the request before anything reaches the wire. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + max_tokens=100, + tools=[types.Tool(name="get_weather", input_schema={"type": "object"})], + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the validator rejects every tool-enabled request + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the plain sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="-32602: Client does not support sampling tools capability")]) + ) + + +@requirement("sampling:tool-result:no-mixed-content") +async def test_create_message_with_unbalanced_tool_messages_is_rejected(connect: Connect) -> None: + """A sampling request whose messages mix tool results with other content never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, even + when no tools are passed, so the client callback is never invoked and the handler observes the + ValueError directly. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "summarise_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + ToolResultContent(tool_use_id="call-1", content=[TextContent(text="42")]), + TextContent(text="Also, a comment alongside the result."), + ], + ) + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("summarise_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="ValueError: The last message must contain only tool_result content if any is present") + ] + ) + ) + + +@requirement("sampling:capability:declare") +async def test_a_client_with_a_sampling_callback_declares_the_sampling_capability(connect: Connect) -> None: + """A client connecting with a sampling callback advertises the sampling capability to the server. + + Client cannot declare any sub-capabilities (it does not expose ClientSession's + sampling_capabilities parameter), so the snapshot pins an empty SamplingCapability. 
+ """ + captured: list[SamplingCapability | None] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="capabilities", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "capabilities" + assert ctx.session.client_params is not None + captured.append(ctx.session.client_params.capabilities.sampling) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Registered only so the sampling capability is advertised; never called.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + await client.call_tool("capabilities", {}) + + assert captured == snapshot([SamplingCapability()]) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_request_with_audio_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying audio content arrives at the client callback intact. + + This is the server-to-client direction: the server includes audio in the conversation it asks + the client to sample from. 
+ """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="transcribe", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "transcribe" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + audio = params.messages[0].content + assert isinstance(audio, AudioContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"transcribed {audio.mime_type} ({audio.data})"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("transcribe", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="transcribed audio/wav (c25k)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_result_with_audio_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is audio is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is audio rather than text. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="speak", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "speak" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello, aloud."))], + max_tokens=100, + ) + audio = result.content + assert isinstance(audio, AudioContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {audio.mime_type} {audio.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=AudioContent(data="aGVsbG8=", mime_type="audio/wav"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("speak", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-audio-1: audio/wav aGVsbG8=")])) + + +@requirement("sampling:message:content-cardinality") +async def test_create_message_with_list_valued_message_content_reaches_callback(connect: Connect) -> None: + """A sampling message whose content is a list of blocks arrives at the client callback as a list.""" + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="caption", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "caption" + result = await 
ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + content = params.messages[0].content + assert isinstance(content, list) + return CreateMessageResult( + role="assistant", content=TextContent(text=f"{len(content)} blocks"), model="mock-llm-1" + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("caption", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="2 blocks")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:tool-use:result-balance") +async def test_create_message_with_mismatched_tool_use_and_result_ids_is_rejected(connect: Connect) -> None: + """A sampling request whose tool_result ids do not match the preceding tool_use ids never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, so the + client callback is never invoked and the handler observes the ValueError directly. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="continue_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "continue_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="assistant", + content=[ToolUseContent(id="call-1", name="weather", input={})], + ), + SamplingMessage( + role="user", + content=[ToolResultContent(tool_use_id="call-WRONG", content=[TextContent(text="42")])], + ), + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("continue_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="ValueError: ids of tool_result blocks and tool_use blocks from previous message do not match" + ) + ] + ) + ) + + +@requirement("sampling:result:no-tools-single-content") +async def test_array_content_result_for_a_tool_free_request_surfaces_as_a_validation_error(connect: Connect) -> None: + """An array-content sampling result for a tool-free request is accepted by the client and fails server-side. + + Only the exception type is asserted: the message is pydantic's, which changes across releases. 
+ See the divergence note on the requirement: the intended behaviour is that the client rejects + the result; instead the client accepts it and the server's response parsing raises. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Two thoughts, please."))], + max_tokens=100, + ) + except pydantic.ValidationError as exc: + return CallToolResult(content=[TextContent(text=type(exc).__name__)]) + raise NotImplementedError # the array-content result fails server-side parsing every time + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResultWithTools: + return CreateMessageResultWithTools( + role="assistant", + content=[TextContent(text="First thought."), TextContent(text="Second thought.")], + model="mock-llm-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ValidationError")])) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py new file mode 100644 index 0000000000..a9c83d641d --- /dev/null +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -0,0 +1,114 @@ +"""Request timeouts against the low-level Server, driven through the public Client API. 
+ +The handler blocks on an event that is never set, so the awaited response can never arrive and +any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore +set to an effectively-zero duration: the tests add no wall-clock time to the suite. (Zero itself +cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:timeout:basic") +@requirement("protocol:timeout:sends-cancellation") +async def test_request_timeout_fails_the_pending_call() -> None: + """A request whose response does not arrive within its read timeout fails with a timeout error. + + No cancellation is sent to the server (see the divergence note on the requirement): the handler + starts and is still running after the caller has already given up. The test waits for the + handler to have started only after the timeout has fired, so the timeout itself races nothing. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + # The request was already on the wire: the handler still runs even though the caller gave up. 
+ with anyio.fail_after(5): + await handler_started.wait() + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 1e-06 seconds.", + ) + ) + + +@requirement("protocol:timeout:session-survives") +async def test_session_serves_requests_after_timeout() -> None: + """A timed-out request does not poison the session: the next request succeeds.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError): + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:timeout:session-default") +async def test_session_level_timeout_applies_to_every_request() -> None: + """A read timeout configured on the client applies to requests that do not set their own.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + # The one real wall-clock wait in the suite, and 
it cannot be made effectively zero like the + # per-request timeouts: a session-level timeout also governs the initialize handshake, so the + # value must be long enough for the in-process handshake to complete before the blocked tool + # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual + # latency; lowering it only erodes the margin against CI scheduler jitter without saving + # anything perceptible. + async with Client(server, read_timeout_seconds=0.05) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", + ) + ) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py new file mode 100644 index 0000000000..95bb6bd790 --- /dev/null +++ b/tests/interaction/lowlevel/test_tools.py @@ -0,0 +1,512 @@ +"""Tool interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + CallToolResult, + EmbeddedResource, + ErrorData, + Icon, + ImageContent, + ListToolsResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, + ToolAnnotations, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + 
tools=[types.Tool(name="add", description="Add two integers.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) + + +@requirement("tools:call:is-error") +async def test_call_tool_execution_error_is_returned_as_result(connect: Connect) -> None: + """A tool reporting its own failure with is_error=True reaches the client as a result, not an exception. + + Tool execution errors are part of the result so the caller (typically a model) can see + them; only protocol-level failures become JSON-RPC errors. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "flux" + return CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + ) + + +@requirement("tools:call:unknown-name") +async def test_call_tool_unknown_tool_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised tool name with MCPError produces a JSON-RPC error. + + The error's code, message, and data chosen by the handler reach the client verbatim. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown tool: {params.name}", data={"requested": params.name}) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("nope", {}) + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Unknown tool: nope", data={"requested": "nope"}) + ) + + +@requirement("protocol:error:internal-error") +async def test_call_tool_uncaught_exception_becomes_error_response(connect: Connect) -> None: + """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. + + The low-level server reports it with code 0 and the exception text as the message; see the + divergence note on the requirement. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "explode" + raise ValueError("boom") + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("explode", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="boom")) + + +@requirement("tools:list:basic") +async def test_list_tools_returns_registered_tools(connect: Connect) -> None: + """The tools advertised by the server's list handler arrive at the client unchanged.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] 
+ ) + + server = Server("calculator", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + ) + + +@requirement("tools:input-schema:json-schema-2020-12") +@requirement("tools:input-schema:preserve-additional-properties") +@requirement("tools:input-schema:preserve-defs") +@requirement("tools:input-schema:preserve-schema-dialect") +async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Connect) -> None: + """A rich JSON Schema 2020-12 inputSchema reaches the client unchanged and the tool is callable. + + The single identity assertion below proves all four pass-through behaviours at once: the same + dict literal that was registered is the dict that arrives, so $schema, $defs, the nested object + property, and additionalProperties are each preserved by virtue of the whole schema being + preserved. The follow-up call proves the rich-schema tool is callable end to end. 
+ """ + schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": {"positive": {"type": "integer", "exclusiveMinimum": 0}}, + "properties": { + "count": {"$ref": "#/$defs/positive"}, + "options": { + "type": "object", + "properties": {"verbose": {"type": "boolean"}}, + "additionalProperties": False, + }, + }, + "required": ["count"], + "additionalProperties": False, + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="typed", input_schema=schema)]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "typed" + assert params.arguments == {"count": 3, "options": {"verbose": True}} + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("typed", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + listed = await client.list_tools() + called = await client.call_tool("typed", {"count": 3, "options": {"verbose": True}}) + + assert listed.tools[0].input_schema == schema + assert called == snapshot(CallToolResult(content=[TextContent(text="ok")])) + + +@requirement("tools:list:metadata") +async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: + """Every optional Tool field the server supplies reaches the client unchanged.""" + + tool = Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + + async def list_tools(ctx: ServerRequestContext, params: 
types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[tool]) + + server = Server("annotated", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + ] + ) + ) + + +@requirement("tools:call:content:mixed") +@requirement("tools:call:content:image") +@requirement("tools:call:content:audio") +@requirement("tools:call:content:resource-link") +@requirement("tools:call:content:embedded-resource") +async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: + """A tool result can mix every content block type; all of them arrive in order. + + The payloads are tiny fixed base64 strings ("aW1n" is b"img", "YXVk" is b"aud") so the + snapshot pins the exact bytes the client receives. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="render", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "render" + return CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + + server = Server("renderer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("render", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + ) + + +@requirement("tools:call:structured-content") +async def test_call_tool_structured_content(connect: Connect) -> None: + """A tool result carrying structured content alongside content delivers both to the client.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="sum", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "sum" + return CallToolResult(content=[TextContent(text="the 
sum is 5")], structured_content={"sum": 5}) + + server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("sum", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) + + +@requirement("tools:call:concurrent") +async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> None: + """Two tool calls in flight at once run concurrently and each caller gets its own answer. + + Both handlers are held on a shared event after signalling that they have started, and the test + only releases them once both signals have arrived -- a server that processed requests + sequentially would never start the second handler and the test would time out instead. + """ + started: list[str] = [] + started_events = {"first": anyio.Event(), "second": anyio.Event()} + release = anyio.Event() + results: dict[str, CallToolResult] = {} + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + tag = params.arguments["tag"] + assert isinstance(tag, str) + started.append(tag) + started_events[tag].set() + await release.wait() + return CallToolResult(content=[TextContent(text=tag)]) + + server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_record(tag: str) -> None: + results[tag] = await client.call_tool("echo", {"tag": tag}) + + task_group.start_soon(call_and_record, "first") + task_group.start_soon(call_and_record, "second") + + # 
Both handlers are running at the same time before either is allowed to finish. + await started_events["first"].wait() + await started_events["second"].wait() + release.set() + + assert sorted(started) == ["first", "second"] + assert results == snapshot( + { + "first": CallToolResult(content=[TextContent(text="first")]), + "second": CallToolResult(content=[TextContent(text="second")]), + } + ) + + +@requirement("client:output-schema:validate") +async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client(connect: Connect) -> None: + """A result whose structured content does not conform to the tool's declared output schema never + reaches the caller: the client validates it against the schema cached from tools/list and raises. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")], structured_content={"temperature": "warm"}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + # The message embeds the jsonschema validation error, so only the SDK-authored prefix is pinned. 
+ assert str(exc_info.value).startswith("Invalid structured content returned by tool forecast") + + +@requirement("client:output-schema:skip-on-error") +async def test_is_error_result_bypasses_client_output_schema_validation(connect: Connect) -> None: + """A tool result with isError true is returned as-is even when its structured content violates the schema. + + The schema is cached up front so the client could validate, proving the bypass is specifically the + isError flag and not an empty cache. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult( + content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True + ) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + result = await client.call_tool("forecast", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True) + ) + + +@requirement("client:output-schema:missing-structured") +async def test_declared_output_schema_with_no_structured_content_is_rejected_by_the_client(connect: Connect) -> None: + """A tool that declared an output schema but returned no structuredContent fails the client-side check. + + The error is the SDK's own message, so the full text is snapshotted. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")]) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + assert str(exc_info.value) == snapshot("Tool forecast has an output schema but did not return structured content") + + +@requirement("client:output-schema:auto-list") +async def test_call_tool_populates_the_output_schema_cache_via_an_implicit_tools_list(connect: Connect) -> None: + """Calling a tool whose schema is not cached issues exactly one implicit tools/list to populate it. + + The first call_tool of an uncached tool triggers a tools/list the caller never asked for; the + second call hits the cache and does not. This is the SDK's chosen cache strategy and the cause of + the surprising behaviour where a server with only on_call_tool sees a successful call answered + with METHOD_NOT_FOUND from a request the caller never made; see the divergence on the requirement. 
+ """ + list_calls: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + list_calls.append("called") + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + first = await client.call_tool("forecast", {}) + assert list_calls == ["called"] + second = await client.call_tool("forecast", {}) + + assert list_calls == ["called"] + assert first == snapshot(CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21})) + assert second == first diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py new file mode 100644 index 0000000000..a3453b7b2a --- /dev/null +++ b/tests/interaction/lowlevel/test_wire.py @@ -0,0 +1,303 @@ +"""Wire-level invariants observed at the client's transport boundary. + +These behaviours are invisible to API callers -- they are properties of the raw JSON-RPC frames. +The tests wrap the in-memory transport in a RecordingTransport, which tees every message crossing +the transport seam into a list without touching the session, so the assertions hold for whatever +the session implementation sends rather than for what its API returns. + +The later tests drive the wire by hand instead: one closes the server-to-client stream while a +request is in flight to pin the connection-closed teardown, and the last two send deliberately +malformed JSON-RPC requests that the typed client API cannot produce. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CONNECTION_CLOSED, + INVALID_PARAMS, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + EmptyResult, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + TextContent, +) +from tests.interaction._helpers import RecordingTransport, _RecordingReadStream +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _echo_server() -> Server: + """A server with one echo tool, shared by the RecordingTransport-based tests in this module.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="ok")]) + + return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool) + + +@requirement("protocol:request-id:unique") +async def test_request_ids_are_unique_and_never_null() -> None: + """Every request the client sends carries a distinct, non-null id. + + The id sequence is pinned: sequential integers from zero, in send order.
+ """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + await client.call_tool("echo", {}) + await client.call_tool("echo", {}) + await client.send_ping() + + sent = [message.message for message in recording.sent] + request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + assert all(request_id is not None for request_id in request_ids) + assert len(request_ids) == len(set(request_ids)) + # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a + # schema-cache refresh here because the explicit tools/list already populated the cache. + assert request_ids == snapshot([0, 1, 2, 3, 4]) + + +@requirement("protocol:notifications:no-response") +async def test_notifications_are_never_answered() -> None: + """A notification produces no response: everything the server sends back answers a request. + + The client sends two notifications (initialized and roots/list_changed) and two requests + (initialize and ping); the messages received from the server must be exactly one response per + request, each carrying the id of the request it answers, and nothing else.
+ """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.send_roots_list_changed() + await client.send_ping() + + sent = [message.message for message in recording.sent] + sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + sent_notifications = [message for message in sent if isinstance(message, JSONRPCNotification)] + received = [message.message for message in recording.received if isinstance(message, SessionMessage)] + received_responses = [message for message in received if isinstance(message, JSONRPCResponse)] + + assert len(sent_notifications) == 2 # notifications/initialized and notifications/roots/list_changed + assert len(received_responses) == len(received) # nothing the server sent was anything but a response + assert [message.id for message in received_responses] == sent_request_ids + + +async def test_recording_read_stream_ends_iteration_when_the_sender_closes() -> None: + """The recording wrapper preserves the end-of-stream behaviour of the stream it wraps. + + This exercises the helper itself rather than an interaction-model behaviour: a transport whose + far end closes must end the client's receive loop cleanly, and the wrapper must not swallow or + mistranslate that. + """ + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + log: list[SessionMessage | Exception] = [] + async with send_stream, _RecordingReadStream(receive_stream, log) as wrapped: + await send_stream.aclose() + items = [item async for item in wrapped] + assert items == [] + assert log == [] + + +@requirement("lifecycle:initialized-notification") +async def test_exactly_one_initialized_notification_is_sent_after_the_handshake() -> None: + """The client sends initialized exactly once, between the initialize response and its first request. + + The full method sequence the client puts on the wire is pinned in send order. 
+ """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + + sent_methods = [ + message.message.method + for message in recording.sent + if isinstance(message.message, JSONRPCRequest | JSONRPCNotification) + ] + assert sent_methods.count("notifications/initialized") == 1 + assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"]) + + +@requirement("protocol:error:connection-closed") +async def test_closing_the_transport_fails_in_flight_requests_with_connection_closed() -> None: + """When the server-to-client stream closes, every in-flight client request fails with CONNECTION_CLOSED. + + Driven over a bare ClientSession against a real Server so the test holds the transport stream + pair directly: once the request is in flight (the server handler signals it has started) the + test closes the server's write stream, which ends the client's receive loop and triggers the + teardown that fails the pending request. 
+ """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + errors: list[ErrorData] = [] + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + await session.initialize() + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await session.send_request( + CallToolRequest(params=CallToolRequestParams(name="block")), CallToolResult + ) + errors.append(exc_info.value.error) + + async with anyio.create_task_group() as task_group: + task_group.start_soon(call_and_capture_error) + await handler_started.wait() + await server_write.aclose() + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=CONNECTION_CLOSED, message="Connection closed")]) + + +@requirement("protocol:error:invalid-params") +async def test_malformed_request_params_are_answered_with_invalid_params() -> None: + """A request whose params fail validation is answered with -32602 Invalid params. + + The typed client API cannot construct a request with the wrong parameter types, so the test + plays the client's side of the wire by hand against a real Server: it completes the + initialization handshake at the JSON-RPC layer and then sends a tools/call whose `name` is an + integer. 
Reserve this pattern for behaviour the typed API cannot produce. + """ + server = Server("strict") + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": 42})) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")]) + + +@requirement("logging:set-level:invalid-level") +async def test_set_level_with_an_unrecognized_value_is_answered_with_invalid_params() -> None: + """logging/setLevel with a value outside the spec's level enum is answered with -32602 Invalid params. 
+ + The typed client API cannot construct a setLevel request with an unrecognized level (pyright and + the client-side model both reject it), so the test plays the client's side of the wire by hand + against a real Server. Reserve this pattern for behaviour the typed API cannot produce. + """ + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; never called -- params validation fails first.""" + raise NotImplementedError + + server = Server("logger", on_set_logging_level=set_logging_level) + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "loud"}) + ) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert len(errors) == 1 + assert errors[0].code == 
INVALID_PARAMS diff --git a/tests/interaction/mcpserver/__init__.py b/tests/interaction/mcpserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/mcpserver/test_completion.py b/tests/interaction/mcpserver/test_completion.py new file mode 100644 index 0000000000..7761066e94 --- /dev/null +++ b/tests/interaction/mcpserver/test_completion.py @@ -0,0 +1,38 @@ +"""Completion behaviour against MCPServer, driven through the public Client API.""" + +import pytest + +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + CompletionsCapability, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:completion:capability-auto") +async def test_completion_capability_is_advertised_only_when_a_handler_is_registered(connect: Connect) -> None: + """An MCPServer with a registered completion handler advertises the completions capability; one without does not.""" + with_handler = MCPServer("completer") + + @with_handler.completion() + async def complete( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + async with connect(with_handler) as client: + assert client.initialize_result.capabilities.completions == CompletionsCapability() + + async with connect(MCPServer("plain")) as client: + assert client.initialize_result.capabilities.completions is None diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py new file mode 100644 index 0000000000..26556fea7a --- /dev/null +++ b/tests/interaction/mcpserver/test_context.py @@ -0,0 +1,271 @@ +"""The Context 
convenience methods MCPServer injects into tool functions, observed from the client.""" + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp import MCPError +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + METHOD_NOT_FOUND, + CallToolResult, + ElicitRequestFormParams, + ElicitRequestParams, + ElicitResult, + ErrorData, + Implementation, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:context:logging") +@requirement("logging:capability:declared") +async def test_context_logging_helpers_send_log_notifications(connect: Connect) -> None: + """Each Context logging helper sends a log message notification at the matching severity. + + All four notifications reach the client's logging callback before the tool call returns; none + of them carry a logger name unless one is passed explicitly. The server emits these without + advertising the logging capability (see the divergence note on logging:capability:declared).
+ """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("chatty") + + @mcp.tool() + async def narrate(ctx: Context) -> str: + await ctx.debug("d") + await ctx.info("i") + await ctx.warning("w") + await ctx.error("e") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + result = await client.call_tool("narrate", {}) + advertised_logging = client.initialize_result.capabilities.logging + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="d"), + LoggingMessageNotificationParams(level="info", data="i"), + LoggingMessageNotificationParams(level="warning", data="w"), + LoggingMessageNotificationParams(level="error", data="e"), + ] + ) + # The spec requires servers that emit log notifications to declare the logging capability. + assert advertised_logging is None + + +@requirement("mcpserver:context:progress") +async def test_context_report_progress_sends_progress_notifications(connect: Connect) -> None: + """Context.report_progress sends progress notifications correlated to the calling request. + + The caller's progress callback receives each report, in order, before the tool call returns. 
+ """ + received: list[tuple[float, float | None, str | None]] = [] + mcp = MCPServer("worker") + + @mcp.tool() + async def crunch(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.report_progress(2, 3, "halfway there") + return "crunched" + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with connect(mcp) as client: + result = await client.call_tool("crunch", {}, progress_callback=on_progress) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="crunched")], structured_content={"result": "crunched"}) + ) + assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) + + +@requirement("mcpserver:tool:extra") +async def test_context_exposes_request_id_and_client_info_to_a_tool(connect: Connect) -> None: + """A tool can read the per-request id and the connecting client's identity through Context. + + The request id is non-empty (its concrete value depends on transport-level sequencing, so the + test only checks that the tool reported some id, rather than pinning the literal); the + client info reflects what the caller passed to `Client`.
+ """ + mcp = MCPServer("introspector") + + @mcp.tool() + async def whoami(ctx: Context) -> str: + client_params = ctx.session.client_params + assert client_params is not None + return f"request {ctx.request_id} from {client_params.client_info.name} {client_params.client_info.version}" + + async with connect(mcp, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + assert text.startswith("request ") + assert text.endswith(" from acme-agent 9.9.9") + request_id = text.removeprefix("request ").removesuffix(" from acme-agent 9.9.9") + assert request_id + + +@requirement("protocol:progress:no-token") +async def test_report_progress_without_a_progress_token_sends_nothing(connect: Connect) -> None: + """When the caller supplied no progress callback, Context.report_progress is a silent no-op. + + The tool also emits one log message as a sentinel: the message handler receives only that, + proving the notification pipeline works and no progress notification was sent for the + token-less request. 
+ """ + received: list[IncomingMessage] = [] + mcp = MCPServer("quiet") + + @mcp.tool() + async def mill(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.info("milling done") + return "milled" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + result = await client.call_tool("mill", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="milled")], structured_content={"result": "milled"}) + ) + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="milling done"))] + ) + + +@requirement("mcpserver:context:elicit") +@requirement("tools:call:elicitation-roundtrip") +async def test_context_elicit_returns_typed_result(connect: Connect) -> None: + """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. + + The client sees the JSON schema generated from the model; the accepted content is validated + back into the model and handed to the tool as result.data. 
+ """ + received: list[ElicitRequestParams] = [] + mcp = MCPServer("travel") + + class TravelPreferences(BaseModel): + destination: str + window_seat: bool + + @mcp.tool() + async def book_flight(ctx: Context) -> str: + answer = await ctx.elicit("Where to?", TravelPreferences) + assert isinstance(answer, AcceptedElicitation) + return f"{answer.action}: {answer.data.destination} window={answer.data.window_seat}" + + async def answer_form(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"destination": "Lisbon", "window_seat": True}) + + async with connect(mcp, elicitation_callback=answer_form) as client: + result = await client.call_tool("book_flight", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Where to?", + requested_schema={ + "properties": { + "destination": {"title": "Destination", "type": "string"}, + "window_seat": {"title": "Window Seat", "type": "boolean"}, + }, + "required": ["destination", "window_seat"], + "title": "TravelPreferences", + "type": "object", + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept: Lisbon window=True")], + structured_content={"result": "accept: Lisbon window=True"}, + ) + ) + + +@requirement("mcpserver:context:read-resource") +async def test_context_read_resource_reads_registered_resource(connect: Connect) -> None: + """Context.read_resource lets a tool read a resource registered on the same server. + + The tool reports the MIME type and content it read, proving the resource function ran and its + return value came back through the context. 
+ """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + @mcp.tool() + async def show_config(ctx: Context) -> str: + contents = list(await ctx.read_resource("config://app")) + return "\n".join(f"{item.mime_type}: {item.content!r}" for item in contents) + + async with connect(mcp) as client: + result = await client.call_tool("show_config", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="text/plain: 'theme = dark'")], + structured_content={"result": "text/plain: 'theme = dark'"}, + ) + ) + + +@requirement("logging:message:filtered") +async def test_set_logging_level_is_rejected_and_messages_are_never_filtered(connect: Connect) -> None: + """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. + + The request is rejected with METHOD_NOT_FOUND because MCPServer registers no handler for it, + and every message a tool emits is delivered regardless of level. The spec says the server + should only send messages at or above the configured level; with no way to configure one, + everything is sent. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("unfilterable") + + @mcp.tool() + async def chatter(ctx: Context) -> str: + await ctx.debug("noise") + await ctx.error("signal") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.set_logging_level("error") + + await client.call_tool("chatter", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="noise"), + LoggingMessageNotificationParams(level="error", data="signal"), + ] + ) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py new file mode 100644 index 0000000000..ddea4d8278 --- /dev/null +++ b/tests/interaction/mcpserver/test_prompts.py @@ -0,0 +1,191 @@ +"""Prompt interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + GetPromptResult, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:prompt:decorated") +async def test_list_prompts_derives_arguments_from_signature(connect: Connect) -> None: + """A decorated prompt is listed with arguments derived from the function signature. + + Parameters without a default are required; the description comes from the docstring. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def code_review(code: str, style_guide: str = "pep8") -> str: + """Review a piece of code.""" + raise NotImplementedError # registered for listing only; never rendered + + async with connect(mcp) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", required=True), + PromptArgument(name="style_guide", required=False), + ], + ) + ] + ) + ) + + +@requirement("mcpserver:prompt:decorated") +async def test_get_prompt_renders_function_return(connect: Connect) -> None: + """The decorated function's string return value is rendered as a single user message.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A personalised greeting.""" + return f"Say hello to {name}." + + async with connect(mcp) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Say hello to Ada."))], + ) + ) + + +@requirement("mcpserver:prompt:unknown-name") +async def test_get_unknown_prompt_is_error(connect: Connect) -> None: + """Getting a prompt name that was never registered fails with a JSON-RPC error.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; the test requests a different name.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) + + +@requirement("prompts:get:missing-required-args") +async def test_get_prompt_with_a_missing_required_argument_is_an_error(connect: Connect) -> None: + """Getting a prompt without 
one of its required arguments fails with a JSON-RPC error. + + The missing argument is detected before the prompt function is called, but instead of the + spec's -32602 Invalid params it is reported as error code 0 with the bare exception text (see + the divergence note on the requirement). + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("greet") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) + + +@requirement("mcpserver:prompt:args-validation") +async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_function_runs(connect: Connect) -> None: + """An argument that fails the function signature's type validation is rejected before the function runs. + + The decorated function is wrapped in pydantic's validate_call, so a value that cannot be + coerced to the parameter's annotation fails before the body executes. The function body + raises NotImplementedError to prove it never ran. The error is wrapped in the SDK's stable + rendering-error prefix; the body of the message is raw pydantic output and is not asserted.
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def repeat(phrase: str, count: int) -> str: + """A registered prompt; type validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("repeat", {"phrase": "hi", "count": "many"}) + + assert exc_info.value.error.code == 0 + assert exc_info.value.error.message.startswith("Error rendering prompt repeat: 1 validation error") + + +@requirement("mcpserver:prompt:optional-args") +async def test_get_prompt_with_an_optional_argument_omitted_uses_the_default(connect: Connect) -> None: + """A prompt rendered without one of its optional arguments uses that parameter's default value.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def review(code: str, style: str = "pep8") -> str: + """Review a snippet of code against a style guide.""" + return f"Review {code} per {style}." + + async with connect(mcp) as client: + result = await client.get_prompt("review", {"code": "x = 1"}) + + assert result == snapshot( + GetPromptResult( + description="Review a snippet of code against a style guide.", + messages=[PromptMessage(role="user", content=TextContent(text="Review x = 1 per pep8."))], + ) + ) + + +@requirement("mcpserver:prompt:duplicate-name") +async def test_registering_a_duplicate_prompt_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second prompt with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via the decorator with an explicit name so the test does not + redefine the same function name in this scope. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.prompt(name="greet") + def greet_second() -> str: + """Registered with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_prompts() + result = await client.get_prompt("greet") + + assert [prompt.name for prompt in listed.prompts] == ["greet"] + assert result == snapshot( + GetPromptResult( + description="The first registration; this is the one that wins.", + messages=[PromptMessage(role="user", content=TextContent(text="first"))], + ) + ) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py new file mode 100644 index 0000000000..57b0fdc86d --- /dev/null +++ b/tests/interaction/mcpserver/test_resources.py @@ -0,0 +1,183 @@ +"""Resource interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:resource:static") +async def test_read_static_resource(connect: Connect) -> None: + """A function registered for a fixed URI is served at that URI with its return value as text.""" + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + async with connect(mcp) as client: + result = await client.read_resource("config://app") + + assert result == snapshot( + ReadResourceResult( + 
contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="theme = dark")] + ) + ) + + +@requirement("mcpserver:resource:static") +async def test_list_static_and_templated_resources(connect: Connect) -> None: + """Statically-registered resources appear in resources/list; templated ones only in templates/list. + + The name and description are derived from the function name and docstring; the MIME type + defaults to text/plain. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + raise NotImplementedError # registered for listing only; never read + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + raise NotImplementedError # registered for listing only; never read + + async with connect(mcp) as client: + resources = await client.list_resources() + templates = await client.list_resource_templates() + + assert resources == snapshot( + ListResourcesResult( + resources=[ + Resource( + name="app_config", + uri="config://app", + description="The application configuration.", + mime_type="text/plain", + ) + ] + ) + ) + assert templates == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate( + name="user_profile", + uri_template="users://{user_id}/profile", + description="A user's profile.", + mime_type="text/plain", + ) + ] + ) + ) + + +@requirement("mcpserver:resource:template") +@requirement("resources:read:template-vars") +async def test_read_templated_resource(connect: Connect) -> None: + """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" + mcp = MCPServer("library") + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + return f"profile for {user_id}" + + async with connect(mcp) as client: + result = await client.read_resource("users://42/profile") + + 
assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="users://42/profile", mime_type="text/plain", text="profile for 42")] + ) + ) + + +@requirement("mcpserver:resource:unknown-uri") +async def test_read_unknown_uri_is_error(connect: Connect) -> None: + """Reading a URI that matches no registered resource fails with a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; see the divergence note on the requirement. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """A registered resource; the test reads a different URI.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("config://missing") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown resource: config://missing")) + + +@requirement("mcpserver:resource:read-throws-surfaced") +async def test_resource_function_that_raises_is_surfaced_as_a_jsonrpc_error(connect: Connect) -> None: + """An exception raised by a resource function reaches the caller as a JSON-RPC error. + + MCPServer wraps the failure in a generic error that names only the URI, so the original + exception text is not leaked to the client. The wrapped exception becomes error code 0 the + same way every other unhandled server-side exception does. + """ + mcp = MCPServer("library") + + @mcp.resource("res://boom") + def boom() -> str: + raise RuntimeError("nope") + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("res://boom") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom")) + + +@requirement("mcpserver:resource:duplicate-name") +async def test_registering_a_duplicate_resource_uri_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second static resource at an already-used URI keeps the first registration. 
+ + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The two + registrations use different function names so the test does not redefine a name in this scope; + the resource decorator keys on the URI, not the function name. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def config_first() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.resource("config://app") + def config_second() -> str: + """Registered at a duplicate URI; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_resources() + result = await client.read_resource("config://app") + + assert [resource.uri for resource in listed.resources] == ["config://app"] + assert listed.resources[0].name == "config_first" + assert result == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="first")]) + ) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py new file mode 100644 index 0000000000..f8aa208d7f --- /dev/null +++ b/tests/interaction/mcpserver/test_tools.py @@ -0,0 +1,397 @@ +"""Tool interactions against MCPServer, driven through the public Client API.""" + +from typing import Annotated, Literal + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel, Field + +from mcp import MCPError +from mcp.server.mcpserver import Context, MCPServer +from mcp.server.mcpserver.exceptions import ToolError +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitRequestURLParams, + ErrorData, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._connect import 
Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool function; its return value comes back as text content. + + MCPServer also derives an output schema from the return annotation and attaches the + matching structuredContent to the result. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + return str(a + b) + + async with connect(mcp) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) + + +@requirement("mcpserver:tool:schema-variants") +async def test_complex_parameter_types_are_validated_and_coerced_before_the_tool_runs(connect: Connect) -> None: + """Literal, nested-model, and constrained parameters are validated and coerced from the wire arguments. + + The string "3" is coerced to `int` and the `point` dict to a `Point` instance before the function + body sees them, proving the generated input schema and validation pipeline cover non-trivial types. 
+ """ + mcp = MCPServer("typed") + + class Point(BaseModel): + x: int + y: int + + @mcp.tool() + def place(mode: Literal["fast", "slow"], point: Point, count: Annotated[int, Field(ge=1, le=10)]) -> str: + assert isinstance(point, Point) + return f"{mode} at ({point.x}, {point.y}) x{count}" + + async with connect(mcp) as client: + result = await client.call_tool("place", {"mode": "fast", "point": {"x": "3", "y": 4}, "count": 5}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="fast at (3, 4) x5")], structured_content={"result": "fast at (3, 4) x5"} + ) + ) + + +@requirement("mcpserver:tool:handler-throws") +@requirement("mcpserver:output-schema:skip-on-error") +async def test_call_tool_function_exception_becomes_error_result(connect: Connect) -> None: + """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error. + + The function's `-> str` annotation gives the tool a derived output schema, but the error + result is built before any schema validation runs, so no validation failure is layered on + top of the original exception. 
+ """ + mcp = MCPServer("errors") + + @mcp.tool() + def explode() -> str: + raise ValueError("boom") + + async with connect(mcp) as client: + result = await client.call_tool("explode", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool explode: boom")], is_error=True) + ) + + +@requirement("mcpserver:tool:handler-throws") +async def test_call_tool_tool_error_becomes_error_result(connect: Connect) -> None: + """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" + mcp = MCPServer("errors") + + @mcp.tool() + def flux() -> str: + raise ToolError("flux capacitor offline") + + async with connect(mcp) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool flux: flux capacitor offline")], is_error=True) + ) + + +@requirement("mcpserver:tool:unknown-name") +async def test_call_tool_unknown_name_returns_error_result(connect: Connect) -> None: + """Calling a tool name that was never registered is reported as an is_error result. + + The spec classifies unknown tools as a protocol error; see the divergence note on the + requirement. + """ + mcp = MCPServer("errors") + + @mcp.tool() + def add() -> None: + """A registered tool; the test calls a different name.""" + + async with connect(mcp) as client: + result = await client.call_tool("nope", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) + + +@requirement("mcpserver:tool:output-schema:model") +@requirement("tools:call:structured-content:text-mirror") +async def test_call_tool_model_return_becomes_structured_content(connect: Connect) -> None: + """A tool returning a pydantic model advertises the model's schema as the tool's output schema + and returns the model's fields as structured content alongside a serialised text block. 
+ """ + mcp = MCPServer("weather") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def get_weather() -> Weather: + return Weather(temperature=22.5, conditions="sunny") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("get_weather", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": { + "temperature": {"title": "Temperature", "type": "number"}, + "conditions": {"title": "Conditions", "type": "string"}, + }, + "required": ["temperature", "conditions"], + "title": "Weather", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="""\ +{ + "temperature": 22.5, + "conditions": "sunny" +}\ +""" + ) + ], + structured_content={"temperature": 22.5, "conditions": "sunny"}, + ) + ) + + +@requirement("mcpserver:tool:output-schema:wrapped") +async def test_call_tool_list_return_is_wrapped_in_result_key(connect: Connect) -> None: + """A tool returning a list wraps the value under a "result" key in both the generated output + schema and the structured content. 
+ """ + mcp = MCPServer("primes") + + @mcp.tool() + def primes() -> list[int]: + return [2, 3, 5] + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("primes", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": {"result": {"items": {"type": "integer"}, "title": "Result", "type": "array"}}, + "required": ["result"], + "title": "primesOutput", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="2"), TextContent(text="3"), TextContent(text="5")], + structured_content={"result": [2, 3, 5]}, + ) + ) + + +@requirement("mcpserver:tool:input-validation") +async def test_call_tool_invalid_arguments_become_error_result(connect: Connect) -> None: + """Arguments that fail validation against the tool's signature are reported as an is_error + result describing the failure, not as a protocol error. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + """Validation rejects the arguments before the function is ever called.""" + raise NotImplementedError + + async with connect(mcp) as client: + result = await client.call_tool("add", {"b": 3}) + + # The description is raw pydantic output -- it embeds a pydantic-version-specific + # errors.pydantic.dev URL and the internal `addArguments` model name -- so only the stable + # prefix is asserted; a full snapshot would break on every pydantic upgrade. + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + assert result.content[0].text.startswith("Error executing tool add: 1 validation error") + + +@requirement("mcpserver:output-schema:server-validate") +@requirement("mcpserver:output-schema:missing-structured") +async def test_tool_with_output_schema_returning_mismatched_structured_content_is_an_error_result( + connect: Connect, +) -> None: + """Structured content that fails the tool's own output schema is rejected on the server side. 
+ + A tool annotated `Annotated[CallToolResult, Model]` returns a hand-built CallToolResult while + declaring `Model` as its output schema; MCPServer validates the supplied structured_content + against that schema before returning. The two cases -- a content shape that does not match, + and no structured content at all -- both fail that validation and are reported as is_error + results carrying the (raw pydantic) validation error wrapped in the SDK's stable prefix. + """ + mcp = MCPServer("forecaster") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def mismatched() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")], structured_content={"nope": True}) + + @mcp.tool() + def missing() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")]) + + async with connect(mcp) as client: + mismatched_result = await client.call_tool("mismatched", {}) + missing_result = await client.call_tool("missing", {}) + + # The body of each message is raw pydantic ValidationError output (model name, field paths, + # an errors.pydantic.dev URL) and changes across pydantic versions, so only the SDK's stable + # prefix is asserted. + assert mismatched_result.is_error is True + assert isinstance(mismatched_result.content[0], TextContent) + assert mismatched_result.content[0].text.startswith("Error executing tool mismatched: 2 validation errors") + + assert missing_result.is_error is True + assert isinstance(missing_result.content[0], TextContent) + assert missing_result.content[0].text.startswith("Error executing tool missing: 1 validation error") + + +@requirement("mcpserver:tool:duplicate-name") +async def test_registering_a_duplicate_tool_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second tool with an already-used name keeps the first registration. 
+ + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via add_tool with an explicit name so the test does not + redefine the same function name in this scope. + """ + mcp = MCPServer("duplicates") + + @mcp.tool() + def echo() -> str: + return "first" + + def echo_second() -> str: + """Passed to add_tool with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + mcp.add_tool(echo_second, name="echo") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("echo", {}) + + assert [tool.name for tool in listed.tools] == ["echo"] + assert result == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + + +@requirement("mcpserver:tool:url-elicitation-error") +async def test_decorated_tool_raising_url_elicitation_required_surfaces_as_error_32042(connect: Connect) -> None: + """A decorated tool raising the URL-elicitation-required error reaches the client as error -32042. + + MCPServer wraps every other tool exception as an is_error result; this error is special-cased + so it propagates as the JSON-RPC error the client needs in order to present the listed URL + interactions and retry the call. 
+ """ + mcp = MCPServer("authorizer") + + @mcp.tool() + def read_files() -> str: + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("mcpserver:register:post-connect") +async def test_adding_and_removing_tools_does_not_notify_connected_clients(connect: Connect) -> None: + """Mutating the tool set on a running server changes tools/list but sends no notification. + + add_tool and remove_tool only update the registry: a connected client that listed the tools + before the mutation has no way to learn it should list them again. The spec provides + notifications/tools/list_changed for exactly this; MCPServer never sends it. The tool emits + one log message as a sentinel so the test proves notifications do reach the collector -- the + log message arrives, a list_changed does not. 
+ """ + received: list[IncomingMessage] = [] + mcp = MCPServer("mutable") + + def extra() -> str: + """A tool registered at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + def doomed() -> str: + """A tool removed at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + async def grow(ctx: Context) -> str: + mcp.add_tool(extra, name="extra") + mcp.remove_tool("doomed") + await ctx.info("tool set changed") + return "mutated" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + before = await client.list_tools() + await client.call_tool("grow", {}) + after = await client.list_tools() + + assert [tool.name for tool in before.tools] == ["doomed", "grow"] + assert [tool.name for tool in after.tools] == ["grow", "extra"] + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="tool set changed"))] + ) diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py new file mode 100644 index 0000000000..47b1b95e71 --- /dev/null +++ b/tests/interaction/test_coverage.py @@ -0,0 +1,104 @@ +"""Enforces the contract between the requirements manifest and the test suite. + +The contract runs in both directions: every non-deferred entry in :data:`REQUIREMENTS` must be +exercised by at least one test, and every test in the suite must carry at least one +`@requirement(...)` mark referencing a manifest entry. Deferral reasons that point at coverage +elsewhere in the repo must point at paths that exist. Test modules are imported directly +(rather than relying on pytest collection) so the check holds even when only this file is run. 
+""" + +import importlib +import re +from pathlib import Path +from types import ModuleType + +import pytest + +from tests.interaction._requirements import REQUIREMENTS, Requirement, covered_by, requirement + +_SUITE_ROOT = Path(__file__).parent +_REPO_ROOT = _SUITE_ROOT.parent.parent + +# Repo paths cited inside deferral reasons ("Covered by tests/... "). +_CITED_PATH = re.compile(r"(?:tests|src)/[\w./-]*\w") + +# Tests that exercise the suite's own helpers rather than an interaction-model behaviour. +# Anything listed here is exempt from the every-test-has-a-requirement check. +_HARNESS_SELF_TESTS = { + "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", + "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", + "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", + "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", + "tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", +} + + +def _import_all_test_modules() -> list[ModuleType]: + """Import every other test module in the suite so their `@requirement` decorators register.""" + modules: list[ModuleType] = [] + for path in sorted(_SUITE_ROOT.rglob("test_*.py")): + relative = path.relative_to(_SUITE_ROOT).with_suffix("") + name = f"{__package__}.{'.'.join(relative.parts)}" + if name != __name__: + modules.append(importlib.import_module(name)) + return modules + + +def test_every_requirement_is_exercised() -> None: + """Each non-deferred requirement is covered by at least one test (deferred ones by none).""" + _import_all_test_modules() + + uncovered = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is None and not covered_by(requirement_id) + ] + assert not uncovered, f"Requirements 
with no test and no deferred reason: {uncovered}" + + stale_deferrals = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is not None and covered_by(requirement_id) + ] + assert not stale_deferrals, f"Deferred requirements that now have tests (remove deferred): {stale_deferrals}" + + +def test_every_test_exercises_a_requirement() -> None: + """Each test in the suite carries at least one `@requirement` mark (harness self-tests excepted).""" + all_tests = { + f"{module.__name__}.{name}" + for module in _import_all_test_modules() + for name in vars(module) + if name.startswith("test_") + } + linked_tests = {test_name for requirement_id in REQUIREMENTS for test_name in covered_by(requirement_id)} + + unlinked = sorted(all_tests - linked_tests - _HARNESS_SELF_TESTS) + assert not unlinked, f"Tests with no @requirement mark: {unlinked}" + + stale_exemptions = sorted(_HARNESS_SELF_TESTS - all_tests) + assert not stale_exemptions, f"Harness self-test exemptions that no longer exist: {stale_exemptions}" + + +def test_deferral_reasons_cite_existing_paths() -> None: + """Every repo path named in a deferral reason exists, so coverage pointers cannot rot.""" + missing = sorted( + f"{requirement_id}: {cited}" + for requirement_id, spec in REQUIREMENTS.items() + if spec.deferred is not None + for cited in _CITED_PATH.findall(spec.deferred) + if not (_REPO_ROOT / cited).exists() + ) + assert not missing, f"Deferral reasons citing paths that do not exist: {missing}" + + +def test_unknown_requirement_id_is_rejected() -> None: + """Marking a test with an ID that is not in the manifest fails at decoration time.""" + with pytest.raises(KeyError, match="Unknown requirement id 'tools:call:does-not-exist'"): + requirement("tools:call:does-not-exist") + + +def test_invalid_requirement_source_is_rejected() -> None: + """A requirement whose source is not a spec URL, 'sdk', or an issue reference fails at construction.""" + with 
pytest.raises(ValueError, match="source must be a specification URL"): + Requirement(source="https://example.com/not-the-spec", behavior="Never constructed.") diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py new file mode 100644 index 0000000000..6d0bfd62d4 --- /dev/null +++ b/tests/interaction/transports/_bridge.py @@ -0,0 +1,164 @@ +"""An in-process, full-duplex HTTP transport for driving ASGI applications from httpx. + +`httpx.ASGITransport` runs the application to completion and only then hands the buffered +response to the caller, so a server that streams its response — the streamable HTTP transport's +SSE responses — can never converse with the client mid-request: a server-initiated request +nested inside a still-open call deadlocks. `StreamingASGITransport` removes that limitation by +running the application as a background task and forwarding every `http.response.body` chunk to +the client the moment it is sent. Everything happens on the one event loop: no sockets, no +threads, no sleeps, no extra dependencies. + +The behavioural contract, pinned by `test_bridge.py`: + +- The request body is buffered before the application is invoked (MCP requests are small JSON + documents); the response streams chunk by chunk. +- Closing the response — or the whole client — delivers `http.disconnect` to the application, + exactly as a real server sees when its peer goes away. +- An exception the application raises before sending `http.response.start` fails the originating + request with that same exception. After the response has started, a failure is visible to the + client only through the response itself (status code, truncated body) — the same signal a real + server over a real socket would give. 
+ +The transport owns an anyio task group for the application tasks; it is opened and closed by +`httpx.AsyncClient`'s own context manager, so use the client as a context manager (the suite +always does). Closing the transport cancels every running application task by default; set +`cancel_on_close=False` to wait for the application's own disconnect handling instead. +""" + +import math +from collections.abc import AsyncIterator +from types import TracebackType + +import anyio +import anyio.abc +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream +from starlette.types import ASGIApp, Message, Scope + + +class _StreamingResponseBody(httpx.AsyncByteStream): + """A response body that yields chunks as the application produces them. + + Closing it tells the application the client has gone away (`http.disconnect`), mirroring a + peer that drops the connection mid-response. + """ + + def __init__(self, chunks: MemoryObjectReceiveStream[bytes], client_disconnected: anyio.Event) -> None: + self._chunks = chunks + self._client_disconnected = client_disconnected + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for chunk in self._chunks: + yield chunk + + async def aclose(self) -> None: + self._client_disconnected.set() + await self._chunks.aclose() + + +class StreamingASGITransport(httpx.AsyncBaseTransport): + """Drive an ASGI application in-process, streaming each response as it is produced. + + With `cancel_on_close` (the default), closing the transport cancels every application task + still running so harness teardown can never hang. Setting it to False makes the transport wait + for the application's own disconnect handling to complete instead, which is the path the legacy + SSE server transport relies on for resource cleanup. 
+ """ + + _task_group: anyio.abc.TaskGroup + + def __init__(self, app: ASGIApp, *, cancel_on_close: bool = True) -> None: + self._app = app + self._cancel_on_close = cancel_on_close + + async def __aenter__(self) -> "StreamingASGITransport": + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + # httpx closes every streamed response before closing the transport, so by now each + # application task has been delivered `http.disconnect`. Either cancel immediately, or wait + # for the application's own disconnect handling to unwind. + if self._cancel_on_close: + self._task_group.cancel_scope.cancel() + await self._task_group.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + request_body = b"".join([chunk async for chunk in request.stream]) + + scope: Scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?", maxsplit=1)[0], + "query_string": request.url.query, + "root_path": "", + "headers": [(name.lower(), value) for name, value in request.headers.raw], + "server": (request.url.host, request.url.port), + "client": ("127.0.0.1", 1234), + } + + request_delivered = False + client_disconnected = anyio.Event() + response_started = anyio.Event() + response_status = 0 + response_headers: list[tuple[bytes, bytes]] = [] + application_error: Exception | None = None + chunk_writer, chunk_reader = anyio.create_memory_object_stream[bytes](math.inf) + + async def receive_request() -> Message: + nonlocal request_delivered + if not request_delivered: + request_delivered = 
True + return {"type": "http.request", "body": request_body, "more_body": False} + await client_disconnected.wait() + return {"type": "http.disconnect"} + + async def send_response(message: Message) -> None: + nonlocal response_status, response_headers + if message["type"] == "http.response.start": + response_status = message["status"] + response_headers = list(message.get("headers", [])) + response_started.set() + return + assert message["type"] == "http.response.body" + body: bytes = message.get("body", b"") + if body: + await chunk_writer.send(body) + if not message.get("more_body", False): + await chunk_writer.aclose() + + async def run_application() -> None: + nonlocal application_error + try: + await self._app(scope, receive_request, send_response) + except Exception as exc: # The bridge is the application's outermost boundary: a crash + # must fail the originating request (or show up in the already-started response), + # never tear down the task group shared with every other in-flight request. + application_error = exc + finally: + response_started.set() + await chunk_writer.aclose() + + self._task_group.start_soon(run_application) + await response_started.wait() + if application_error is not None: + # No response will be built, so close the reader the response body would have owned. + await chunk_reader.aclose() + raise application_error + return httpx.Response( + status_code=response_status, + headers=response_headers, + stream=_StreamingResponseBody(chunk_reader, client_disconnected), + request=request, + ) diff --git a/tests/interaction/transports/_event_store.py b/tests/interaction/transports/_event_store.py new file mode 100644 index 0000000000..84d1a2646a --- /dev/null +++ b/tests/interaction/transports/_event_store.py @@ -0,0 +1,55 @@ +"""A predictable event store for resumability tests. 
+ +The SDK's `EventStore` interface lets a streamable-HTTP server stamp every SSE event with an ID +and replay missed events when a client reconnects with `Last-Event-ID`. This implementation +issues sequential integer IDs starting at "1" so tests can assert exact IDs (the example store +uses uuid4, which cannot be snapshotted) and is small enough that every line is exercised by the +resumability tests themselves. +""" + +import anyio + +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.types import JSONRPCMessage + + +class SequencedEventStore(EventStore): + """Stores every event in order and replays the same-stream tail after a given ID.""" + + def __init__(self) -> None: + self._events: list[tuple[StreamId, JSONRPCMessage | None]] = [] + self._milestones: dict[int, anyio.Event] = {} + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + self._events.append((stream_id, message)) + count = len(self._events) + milestone = self._milestones.pop(count, None) + if milestone is not None: + milestone.set() + return str(count) + + async def wait_until_stored(self, count: int) -> None: + """Block until at least `count` events have been stored. + + Tests use this to wait for the server's message router (which runs in another task) to + finish storing a known set of events before issuing a replay, so the replay's content is + deterministic rather than depending on task scheduling order. 
+ """ + if len(self._events) >= count: + return + milestone = self._milestones.setdefault(count, anyio.Event()) + await milestone.wait() + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + try: + cursor = int(last_event_id) + except ValueError: + return None + if not 0 < cursor <= len(self._events): + return None + stream_id, _ = self._events[cursor - 1] + for index in range(cursor, len(self._events)): + event_stream_id, message = self._events[index] + if event_stream_id == stream_id and message is not None: + await send_callback(EventMessage(message, str(index + 1))) + return stream_id diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py new file mode 100644 index 0000000000..fbe7e614f7 --- /dev/null +++ b/tests/interaction/transports/_stdio_server.py @@ -0,0 +1,56 @@ +"""A real low-level Server over the stdio transport, for the suite's one subprocess test. + +Runnable as `python -m tests.interaction.transports._stdio_server` from the repo root; the test +launches it that way via `stdio_client`. Kept separate from the test module so the server lives in +its own importable file (subprocess coverage applies) while the test file follows the suite's +test-only-functions convention. 
+""" + +import sys + +import anyio + +from mcp.server import Server, ServerRequestContext +from mcp.server.stdio import stdio_server +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + + +async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + input_schema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, + ) + ] + ) + + +async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + text = params.arguments["text"] + await ctx.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") + return CallToolResult(content=[TextContent(text=text)]) + + +server = Server("stdio-echo", on_list_tools=list_tools, on_call_tool=call_tool) + + +async def main() -> None: + async with stdio_server() as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) + # Reached only when the run loop exits because stdin closed; if the process were terminated + # the test's stderr capture would not see this line. + print("stdio-echo: clean exit", file=sys.stderr, flush=True) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py new file mode 100644 index 0000000000..71be14ced0 --- /dev/null +++ b/tests/interaction/transports/test_bridge.py @@ -0,0 +1,92 @@ +"""Contract tests for the suite's streaming ASGI bridge. 
+ +These pin what `StreamingASGITransport` itself guarantees — chunk-by-chunk delivery, disconnect +propagation, and failure handling — against minimal hand-written ASGI applications, so the MCP +transport tests built on top of it never have to wonder what the harness provides. They are +harness self-tests, not interaction-model tests, and are exempted from the requirement-coverage +contract in `test_coverage.py`. +""" + +import anyio +import httpx +import pytest +from starlette.types import Message, Receive, Scope, Send + +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def test_response_chunks_arrive_as_the_application_sends_them() -> None: + """Each body chunk is delivered as sent, empty chunks are skipped, and the stream ends with the application.""" + + async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"text/plain")]}) + await send({"type": "http.response.body", "body": b"first", "more_body": True}) + await send({"type": "http.response.body", "body": b"", "more_body": True}) + await send({"type": "http.response.body", "body": b"second", "more_body": False}) + + async with httpx.AsyncClient(transport=StreamingASGITransport(chunked_app), base_url="http://bridge") as http: + async with http.stream("GET", "/chunks") as response: + with anyio.fail_after(5): + chunks = [chunk async for chunk in response.aiter_raw()] + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/plain" + assert chunks == [b"first", b"second"] + + +async def test_closing_the_response_delivers_a_disconnect_to_the_application() -> None: + """A client that closes the response early is seen by the application as an http.disconnect.""" + seen_after_request: list[Message] = [] + disconnect_seen = 
anyio.Event() + + async def waiting_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": []}) + seen_after_request.append(await receive()) + disconnect_seen.set() + + async with httpx.AsyncClient(transport=StreamingASGITransport(waiting_app), base_url="http://bridge") as http: + async with http.stream("GET", "/wait") as response: + assert response.status_code == 200 + # Leaving the stream block closes the response while the application is still mid-response. + with anyio.fail_after(5): + await disconnect_seen.wait() + + assert seen_after_request == [{"type": "http.disconnect"}] + + +async def test_an_application_failure_before_the_response_starts_fails_the_request() -> None: + """An exception raised before http.response.start reaches the caller as that same exception.""" + + async def broken_app(scope: Scope, receive: Receive, send: Send) -> None: + raise RuntimeError("the demo application is broken") + + async with httpx.AsyncClient(transport=StreamingASGITransport(broken_app), base_url="http://bridge") as http: + with pytest.raises(RuntimeError, match="the demo application is broken"): + await http.get("/broken") + + +async def test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect() -> None: + """With cancel_on_close=False, an application that runs cleanup after seeing http.disconnect + completes that cleanup before the transport finishes closing.""" + cleanup_ran = anyio.Event() + + async def lingering_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + await receive() + await send({"type": "http.response.start", "status": 200, "headers": []}) + assert (await receive())["type"] == "http.disconnect" + cleanup_ran.set() + + transport = StreamingASGITransport(lingering_app, cancel_on_close=False) + async with 
httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + with anyio.fail_after(5): + async with http.stream("GET", "/linger") as response: + assert response.status_code == 200 + assert not cleanup_ran.is_set() + assert cleanup_ran.is_set() diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py new file mode 100644 index 0000000000..2d9d0c42b6 --- /dev/null +++ b/tests/interaction/transports/test_client_transport_http.py @@ -0,0 +1,244 @@ +"""Behaviour of the streamable-HTTP client transport itself, observed at the wire. + +These tests connect a real `Client` to a real server over the in-process bridge, recording every +HTTP request the SDK client issues, so the assertions are about what the transport sends (headers, +methods, ordering) rather than what the protocol layer on top of it returns. The recording is the +wire-level instrument; the SDK client never exposes these details. +""" + +from collections.abc import AsyncIterator + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from starlette.types import Receive, Scope, Send + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.types import INVALID_REQUEST, CallToolResult, ErrorData, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, client_via_http, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _tooled_server() -> Server: + """A low-level server with one echo tool, used by every test in this file.""" + + async def list_tools(ctx: ServerRequestContext, params: 
types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["text"]))]) + + return Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + +@pytest.fixture +async def recorded() -> AsyncIterator[list[httpx.Request]]: + """Connect a `Client` over a recording HTTP client, list tools, exit, and yield every request sent. + + The HTTP client carries one caller-supplied header (`x-trace`) so its propagation can be + asserted; the recording captures the closing DELETE because it is read after the `Client` has + fully exited. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record, headers={"x-trace": "abc"}) as (http, _): + async with client_via_http(http) as client: + result = await client.list_tools() + assert [tool.name for tool in result.tools] == ["echo"] + + yield requests + + +def _after_initialize(recorded: list[httpx.Request]) -> list[httpx.Request]: + """Every recorded request after the initialize POST (which carries no session yet).""" + assert recorded[0].method == "POST" + assert "mcp-session-id" not in recorded[0].headers + return recorded[1:] + + +@requirement("client-transport:http:custom-client") +@requirement("client-transport:http:custom-headers") +async def test_the_client_uses_the_supplied_http_client_and_propagates_its_headers( + recorded: list[httpx.Request], +) -> None: + """A caller-supplied `httpx.AsyncClient` is used for every request and carries its own headers. 
+ + The recording itself proves the supplied client is the one in use; the propagated header + proves the SDK transport does not replace the caller's client configuration. + """ + # Exact ordering past the first request is not guaranteed (the standalone GET stream is + # scheduled concurrently with later POSTs), so methods are asserted as a multiset. + assert sorted(request.method for request in recorded) == snapshot(["DELETE", "GET", "POST", "POST", "POST"]) + assert all(request.headers["x-trace"] == "abc" for request in recorded) + + +@requirement("client-transport:http:session-stored") +async def test_every_request_after_initialize_carries_the_issued_session_id(recorded: list[httpx.Request]) -> None: + """The session id from the initialize response is sent on every subsequent request.""" + session_ids = {request.headers["mcp-session-id"] for request in _after_initialize(recorded)} + assert len(session_ids) == 1 + (session_id,) = session_ids + assert session_id + + +@requirement("client-transport:http:protocol-version-stored") +@requirement("client-transport:http:protocol-version-header") +async def test_every_request_after_initialize_carries_the_negotiated_protocol_version( + recorded: list[httpx.Request], +) -> None: + """The negotiated protocol version is sent on every subsequent request (and not on initialize).""" + assert "mcp-protocol-version" not in recorded[0].headers + versions = {request.headers["mcp-protocol-version"] for request in _after_initialize(recorded)} + assert versions == snapshot({"2025-11-25"}) + + +@requirement("client-transport:http:accept-header-post") +@requirement("client-transport:http:accept-header-get") +async def test_accept_headers_cover_the_response_representations_the_transport_handles( + recorded: list[httpx.Request], +) -> None: + """POSTs accept both JSON and SSE; the standalone GET stream accepts SSE.""" + for request in recorded: + if request.method == "POST": + assert "application/json" in request.headers["accept"] + 
assert "text/event-stream" in request.headers["accept"] + if request.method == "GET": + assert "text/event-stream" in request.headers["accept"] + + +@requirement("client-transport:http:no-reconnect-after-close") +async def test_closing_the_client_sends_delete_and_does_not_reconnect(recorded: list[httpx.Request]) -> None: + """Client teardown sends DELETE and issues no further requests (no resumption GET).""" + assert recorded[-1].method == "DELETE" + assert all("last-event-id" not in request.headers for request in recorded) + + +@requirement("client-transport:http:concurrent-streams") +async def test_concurrent_tool_calls_each_open_a_post_stream_and_receive_their_own_response() -> None: + """Three tool calls issued at once each open their own POST stream and get the right answer.""" + requests: list[httpx.Request] = [] + results: dict[int, CallToolResult] = {} + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record) as (http, _): + async with client_via_http(http) as client: + + async def call(n: int) -> None: + results[n] = await client.call_tool("echo", {"text": str(n)}) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + for n in (1, 2, 3): + tg.start_soon(call, n) + + assert results == snapshot( + { + 1: CallToolResult(content=[TextContent(text="1")]), + 2: CallToolResult(content=[TextContent(text="2")]), + 3: CallToolResult(content=[TextContent(text="3")]), + } + ) + tools_call_posts = [r for r in requests if r.method == "POST" and b'"tools/call"' in r.content] + assert len(tools_call_posts) == 3 + + +@requirement("client-transport:http:sse-405-tolerated") +@requirement("client-transport:http:terminate-405-ok") +async def test_client_tolerates_405_on_get_and_delete() -> None: + """A 405 on the standalone GET stream or the closing DELETE does not fail the connection. 
+ + The GET-stream task swallows the failure and schedules a reconnect that the closing cancel + interrupts before it ever sleeps the full default delay; the DELETE 405 is logged and ignored. + Neither surfaces to the caller. + """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + + async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] in ("GET", "DELETE"): + await send({"type": "http.response.start", "status": 405, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await real_app(scope, receive, send) + + async with server.session_manager.run(): + http_client = httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) + async with http_client: + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): + async with Client(transport) as client: + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + + +@requirement("client-transport:http:no-reconnect-after-response") +async def test_a_completed_post_stream_is_not_reconnected() -> None: + """A POST stream that delivered its response closes without a resumption GET. + + With an event store the server stamps every SSE event with an ID, so the client transport has a + Last-Event-ID it could resume from -- the test proves it does not, because the response arrived + and the stream completed normally. 
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + server = _tooled_server() + async with mounted_app(server, event_store=SequencedEventStore(), retry_interval=0, on_request=record) as (http, _): + async with client_via_http(http) as client: + with anyio.fail_after(5): + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + resumption_gets = [r for r in requests if r.method == "GET" and "last-event-id" in r.headers] + assert resumption_gets == [] + + +@requirement("client-transport:http:404-surfaces") +async def test_a_404_mid_session_surfaces_as_a_session_terminated_error() -> None: + """A 404 in response to a request after initialization is reported to the caller as an MCP error. + + The spec says the client MUST start a new session in this situation; the SDK instead surfaces a + `Session terminated` error to the caller (see the divergence on the requirement). This test pins + that current behaviour. 
+ """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + initialize_seen = anyio.Event() + + async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] == "POST" and initialize_seen.is_set(): + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + if scope["type"] == "http" and scope["method"] == "POST": + initialize_seen.set() + await real_app(scope, receive, send) + + async with server.session_manager.run(): + http_client = httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) + async with http_client: + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): + async with Client(transport) as client: + with pytest.raises(MCPError) as exc_info: + await client.list_tools() + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_REQUEST, message="Session terminated")) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py new file mode 100644 index 0000000000..6e3d787356 --- /dev/null +++ b/tests/interaction/transports/test_flows.py @@ -0,0 +1,127 @@ +"""Transport-level composed flows: multi-client isolation, reconnection, and dual-transport hosting. + +These scenarios are about how the transport layer holds together across more than one connection +or more than one transport, so they connect real `Client`s against one mounted server rather than +running over the matrix. 
+""" + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.session import LoggingFnT +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import client_via_http, connect_over_sse, mounted_app +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("flow:multi-client:stateful-isolation") +async def test_concurrent_clients_on_one_stateful_server_receive_only_their_own_notifications() -> None: + """Two clients on one stateful manager each receive only the notifications their own request produced. + + Complements `test_terminating_one_session_leaves_others_working` (which proves session + independence under termination) with the notification-isolation dimension: a notification + emitted by one session's handler does not leak to another session's client. + """ + mcp = MCPServer("multi") + + @mcp.tool() + async def announce(label: str, ctx: Context) -> str: + """Emit one info-level log carrying the caller's label, then return it.""" + await ctx.info(label) + return label + + received_a: list[object] = [] + received_b: list[object] = [] + + async def collect_a(params: LoggingMessageNotificationParams) -> None: + received_a.append(params.data) + + async def collect_b(params: LoggingMessageNotificationParams) -> None: + received_b.append(params.data) + + async with mounted_app(mcp) as (http, _): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call(label: str, collect: LoggingFnT) -> None: + async with client_via_http(http, logging_callback=collect) as client: + await client.call_tool("announce", {"label": label}) + + tg.start_soon(call, "a", collect_a) + tg.start_soon(call, "b", collect_b) + + assert received_a == ["a"] + assert received_b == ["b"] + + +@requirement("flow:session:terminate-then-reconnect") +async def 
test_a_fresh_connection_after_termination_obtains_a_new_session_and_operates() -> None: + """After a client terminates, a fresh connection to the same manager gets a distinct session. + + Steps: (1) connect a client and call list_tools, (2) the client exits (its DELETE fires), + (3) connect a second client to the same mounted app, (4) the second client's call_tool + succeeds and the recorded session ids show two distinct sessions were issued. + """ + mcp = MCPServer("reconnectable") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + session_ids: list[str] = [] + + async def record(request: httpx.Request) -> None: + session_id = request.headers.get("mcp-session-id") + if session_id is not None: + session_ids.append(session_id) + + async with mounted_app(mcp, on_request=record) as (http, _): + async with client_via_http(http) as first: + first_result = await first.list_tools() + async with client_via_http(http) as second: + second_result = await second.call_tool("echo", {"text": "again"}) + + assert {tool.name for tool in first_result.tools} == {"echo"} + assert second_result == snapshot( + CallToolResult(content=[TextContent(text="again")], structured_content={"result": "again"}) + ) + distinct = set(session_ids) + assert len(distinct) == 2, f"expected two distinct session ids across the two connections, saw {distinct}" + + +@requirement("flow:compat:dual-transport-server") +async def test_one_server_serves_streamable_http_and_sse_clients_concurrently() -> None: + """One MCPServer instance serves a streamable-HTTP client and a legacy-SSE client at the same time. + + The two transports have independent connection management (the streamable-HTTP session manager + versus a per-connection SSE handler), but both dispatch into the same server's request + handlers. The test connects one client over each transport against the same instance and + proves both reach the same tool. 
Uses MCPServer because the low-level Server has no SSE + convenience; the entry is about hosting composition, not the low-level API. + """ + mcp = MCPServer("dual") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + async with mounted_app(mcp) as (http, _): + async with connect_over_sse(mcp) as sse_client: + async with client_via_http(http) as shttp_client: + with anyio.fail_after(5): + shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) + sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) + + assert shttp_result == snapshot( + CallToolResult(content=[TextContent(text="via http")], structured_content={"result": "via http"}) + ) + assert sse_result == snapshot( + CallToolResult(content=[TextContent(text="via sse")], structured_content={"result": "via sse"}) + ) diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py new file mode 100644 index 0000000000..aa9beee067 --- /dev/null +++ b/tests/interaction/transports/test_hosting_http.py @@ -0,0 +1,297 @@ +"""Streamable HTTP semantics: status codes, header validation, message routing, and security. + +These tests speak HTTP directly to the server's mounted ASGI app via the in-process bridge, +asserting the wire contract -- which status code answers which condition, which stream a message +travels on -- that the SDK client never exposes. Transport-agnostic behaviour is covered by the +`connect`-fixture matrix. 
+""" + +import anyio +import pytest +from anyio.lowlevel import checkpoint +from httpx_sse import ServerSentEvent, aconnect_sse +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + INVALID_PARAMS, + PARSE_ERROR, + CallToolRequestParams, + CallToolResult, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListToolsResult, + PaginatedRequestParams, + TextContent, +) +from tests.interaction._connect import ( + base_headers, + initialize_body, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A low-level server with one tool that emits a related and an unrelated notification.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "narrate" + await ctx.session.send_log_message(level="info", data="related", logger=None, related_request_id=ctx.request_id) + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="done")]) + + return Server("hosted", on_list_tools=list_tools, on_call_tool=call_tool) + + +@requirement("hosting:http:method-405") +async def test_unsupported_http_methods_return_405() -> None: + """PUT and PATCH on the MCP endpoint return 405 with an Allow header naming the supported methods.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + put = await http.put("/mcp", json={}, headers=base_headers(session_id=session_id)) + patch = await http.patch("/mcp", json={}, 
headers=base_headers(session_id=session_id)) + + assert (put.status_code, put.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + assert (patch.status_code, patch.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + + +@requirement("hosting:http:accept-406") +async def test_missing_accept_media_types_return_406() -> None: + """A POST whose Accept header lacks both required types, or a GET lacking text/event-stream, returns 406.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", json=initialize_body(), headers={"accept": "text/plain", "mcp-protocol-version": "2025-11-25"} + ) + session_id = await initialize_via_http(http) + get = await http.get( + "/mcp", + headers={"accept": "application/json", "mcp-protocol-version": "2025-11-25", "mcp-session-id": session_id}, + ) + + assert (post.status_code, post.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept both application/json and text/event-stream") + ) + assert (get.status_code, get.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept text/event-stream") + ) + + +@requirement("hosting:http:content-type-415") +async def test_non_json_content_type_is_rejected() -> None: + """A POST with a non-JSON Content-Type is rejected before reaching the transport. + + See the divergence on the requirement: the security middleware rejects with 400, so the + transport's own 415 path is unreachable through any public entry point. 
+ """ + async with mounted_app(_server()) as (http, _): + response = await http.post( + "/mcp", content=b"", headers=base_headers() | {"content-type": "text/plain"} + ) + + assert (response.status_code, response.text) == snapshot((400, "Invalid Content-Type header")) + + +@requirement("hosting:http:parse-error-400") +@requirement("hosting:http:batch") +async def test_malformed_and_batched_bodies_return_400() -> None: + """A non-JSON body returns 400 Parse error; a JSON array of requests returns 400 Invalid params.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + not_json = await http.post( + "/mcp", + content=b"this is not json", + headers=base_headers(session_id=session_id) | {"content-type": "application/json"}, + ) + batched = await http.post( + "/mcp", + json=[ + {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + ], + headers=base_headers(session_id=session_id), + ) + + assert not_json.status_code == 400 + assert JSONRPCError.model_validate_json(not_json.text).error.code == PARSE_ERROR + assert batched.status_code == 400 + assert JSONRPCError.model_validate_json(batched.text).error.code == INVALID_PARAMS + + +@requirement("hosting:http:protocol-version-400") +@requirement("hosting:http:protocol-version-default") +async def test_protocol_version_header_is_validated() -> None: + """An unsupported MCP-Protocol-Version header returns 400; an absent header is accepted as the default.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + bad = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id) | {"mcp-protocol-version": "1991-01-01"}, + ) + # Only Accept and the session ID -- no MCP-Protocol-Version header at all. 
+ defaulted = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers={"accept": "application/json, text/event-stream", "mcp-session-id": session_id}, + ) + + assert bad.status_code == 400 + assert JSONRPCError.model_validate_json(bad.text).error.message.startswith( + "Bad Request: Unsupported protocol version: 1991-01-01." + ) + # 202 proves the request was accepted under the assumed default version (2025-03-26). + assert defaulted.status_code == 202 + + +@requirement("hosting:http:notifications-202") +async def test_notification_post_returns_202_with_no_body() -> None: + """A POST containing only a notification (no request ID) returns 202 Accepted with no body.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + response = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers=base_headers(session_id=session_id), + ) + + assert (response.status_code, response.content) == snapshot((202, b"")) + + +@requirement("hosting:http:second-sse-rejected") +async def test_a_second_standalone_get_stream_on_the_same_session_returns_409() -> None: + """Opening a second standalone GET SSE stream while one is already established returns 409 Conflict.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as first: + assert first.response.status_code == 200 + # The standalone-stream writer registers its key as its first action, then parks + # awaiting messages; one yield to the loop lets that registration complete before the + # second GET is dispatched. 
@requirement("hosting:http:standalone-sse")
@requirement("hosting:http:standalone-sse-no-response")
@requirement("hosting:http:response-same-connection")
@requirement("hosting:http:sse-close-after-response")
@requirement("hosting:http:no-broadcast")
async def test_messages_are_routed_to_exactly_one_stream() -> None:
    """A server message is delivered on one SSE stream only; nothing is broadcast.

    Two server-to-client streams are open at once: the long-lived standalone stream (GET) and the
    short-lived stream belonging to one POSTed `tools/call`. The tool's handler emits one
    notification related to the call and one unrelated to it. The POST stream must carry the
    related notification followed by the response and then end; the standalone stream must carry
    only the unrelated notification and no JSON-RPC response.
    """
    async with mounted_app(_server()) as (http, _):
        session_id = await initialize_via_http(http)
        standalone_ready = anyio.Event()
        seen_on_standalone = anyio.Event()
        post_events: list[ServerSentEvent] = []
        get_events: list[ServerSentEvent] = []

        async def read_standalone_stream() -> None:
            stream_headers = base_headers(session_id=session_id)
            async with aconnect_sse(http, "GET", "/mcp", headers=stream_headers) as standalone:
                assert standalone.response.status_code == 200
                standalone_ready.set()
                async for sse_event in standalone.aiter_sse():
                    get_events.append(sse_event)
                    seen_on_standalone.set()

        with anyio.fail_after(5):
            async with anyio.create_task_group() as group:
                group.start_soon(read_standalone_stream)
                await standalone_ready.wait()

                call_params = CallToolRequestParams(name="narrate", arguments={})
                request = JSONRPCRequest(jsonrpc="2.0", id=5, method="tools/call", params=call_params.model_dump())
                payload = request.model_dump(by_alias=True, exclude_none=True)
                async with aconnect_sse(
                    http, "POST", "/mcp", json=payload, headers=base_headers(session_id=session_id)
                ) as post:
                    assert post.response.status_code == 200
                    # Iteration stops once the server closes the POST stream after the response.
                    async for sse_event in post.aiter_sse():
                        post_events.append(sse_event)

                # Only stop the standalone reader once it has actually observed a message.
                await seen_on_standalone.wait()
                group.cancel_scope.cancel()

    post_messages = parse_sse_messages(post_events)
    get_messages = parse_sse_messages(get_events)

    # POST stream: the related log notification first, the response second, then the stream ends.
    post_kinds = [type(message).__name__ for message in post_messages]
    assert post_kinds == snapshot(["JSONRPCNotification", "JSONRPCResponse"])
    related, response = post_messages
    assert isinstance(related, JSONRPCNotification)
    assert (related.method, related.params) == snapshot(
        ("notifications/message", {"level": "info", "data": "related"})
    )
    assert isinstance(response, JSONRPCResponse)
    assert response.id == 5

    # Standalone stream: just the unrelated resource-updated notification, and no response.
    get_kinds = [type(message).__name__ for message in get_messages]
    assert get_kinds == snapshot(["JSONRPCNotification"])
    unrelated = get_messages[0]
    assert isinstance(unrelated, JSONRPCNotification)
    assert unrelated.method == snapshot("notifications/resources/updated")
assert [event async for event in unguarded.aiter_sse()] + + assert status == 200 diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py new file mode 100644 index 0000000000..bb98a96e7a --- /dev/null +++ b/tests/interaction/transports/test_hosting_resume.py @@ -0,0 +1,290 @@ +"""Resumability over the streamable HTTP transport, exercised entirely in process. + +These tests configure the server with an event store, so every SSE event is stamped with an ID +and a client that loses its connection can resume by sending `Last-Event-ID`. The wire-level +tests (`mounted_app` + raw httpx) assert exactly what travels on the wire; the end-to-end test +drives the SDK client through a server-initiated stream close and proves the call still +completes. The bridge's `aclose()` delivers `http.disconnect` to the running application, so +closing a streaming response mid-read is a deterministic in-process disconnect -- no sockets, +no real time. Every server here uses `retry_interval=0` so reconnection waits are no-ops. 
+""" + +import json + +import anyio +import httpx +import pytest +from httpx_sse import EventSource, ServerSentEvent +from inline_snapshot import snapshot + +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, + jsonrpc_message_adapter, +) +from tests.interaction._connect import ( + base_headers, + connect_over_streamable_http, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _counting_server() -> MCPServer: + """A server with one tool that emits related notifications and one unrelated notification.""" + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context, n: int) -> str: + """Emit n log notifications related to this call, plus one unrelated resource update.""" + for i in range(1, n + 1): + await ctx.info(f"tick {i}") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return f"counted to {n}" + + return mcp + + +def _tools_call(request_id: int, name: str, arguments: dict[str, object]) -> str: + """A serialized tools/call JSON-RPC request body.""" + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="tools/call", params={"name": name, "arguments": arguments} + ).model_dump_json(by_alias=True, exclude_none=True) + + +async def _read_events(response: httpx.Response, count: int) -> list[ServerSentEvent]: + """Read exactly `count` SSE events from a streaming response without closing it.""" + source = EventSource(response).aiter_sse() + return [await anext(source) for _ in range(count)] + + +@requirement("hosting:resume:event-ids") +@requirement("hosting:resume:priming") +async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_event() -> None: + """A request's 
@requirement("hosting:resume:replay")
@requirement("hosting:resume:stream-scoped")
@requirement("hosting:resume:buffered-replay")
async def test_get_with_last_event_id_replays_only_that_streams_missed_events() -> None:
    """A GET carrying Last-Event-ID replays, in order, what that one stream missed.

    The tool also sends an unrelated notification, which the server files under the
    standalone-stream key. It must be absent from the replay, which shows that replay is scoped to
    the stream owning the supplied event ID.

    Sequence: initialize; POST the tool call and read up to the first notification; drop the
    response mid-stream (the bridge delivers `http.disconnect` while the handler carries on);
    release the handler so its remaining messages land in the event store; wait on the store until
    the response has been recorded, so the replay does not depend on task scheduling; finally GET
    with `Last-Event-ID` and check the replay is exactly this request's missed events.
    """
    release = anyio.Event()
    store = SequencedEventStore()

    mcp = MCPServer("resumable")

    @mcp.tool()
    async def count(ctx: Context) -> str:
        """Emit one related notification, wait for the test, then emit two more plus an unrelated one."""
        await ctx.info("tick 1")
        await release.wait()
        await ctx.info("tick 2")
        await ctx.info("tick 3")
        await ctx.session.send_resource_updated("file:///elsewhere.txt")
        return "counted"

    async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _):
        session_id = await initialize_via_http(http)
        with anyio.fail_after(5):
            call_body = _tools_call(1, "count", {})
            call_headers = base_headers(session_id=session_id)
            async with http.stream("POST", "/mcp", content=call_body, headers=call_headers) as response:
                # Take the priming event and the first notification; leaving the block disconnects.
                priming, first = await _read_events(response, 2)
                assert (priming.id, first.id) == snapshot(("3", "4"))
                last_seen = first.id
            release.set()
            # The handler outlives the disconnect and its later messages are stored. Waiting for
            # 4 returns at once (priming and first tick are already stored); waiting for 8 blocks
            # until the response itself is stored, which pins down the replay's content.
            await store.wait_until_stored(4)
            await store.wait_until_stored(8)
            replay_headers = base_headers(session_id=session_id) | {"last-event-id": last_seen}
            async with http.stream("GET", "/mcp", headers=replay_headers) as replay:
                assert replay.status_code == 200
                missed = await _read_events(replay, 3)

    decoded = parse_sse_messages(missed)
    # The two outstanding related notifications and the response, still under their original IDs.
    missed_ids = [event.id for event in missed]
    assert missed_ids == snapshot(["5", "6", "8"])
    decoded_kinds = [type(message).__name__ for message in decoded]
    assert decoded_kinds == snapshot(["JSONRPCNotification", "JSONRPCNotification", "JSONRPCResponse"])
    assert isinstance(decoded[2], JSONRPCResponse)
    assert decoded[2].id == 1
    # The resource-updated notification belongs to the standalone-stream key rather than this
    # request's stream, so the replay must not contain it.
    assert not any(
        isinstance(message, JSONRPCNotification) and message.method == "notifications/resources/updated"
        for message in decoded
    )
+ """ + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + for unknown in ("no-such-event", "0"): + headers = base_headers(session_id=session_id) | {"last-event-id": unknown} + async with http.stream("GET", "/mcp", headers=headers) as replay: + assert replay.status_code == 200 + assert replay.headers["content-type"].startswith("text/event-stream") + events = [event async for event in EventSource(replay).aiter_sse()] + assert events == [] + + +@requirement("hosting:http:disconnect-not-cancel") +async def test_dropping_the_connection_mid_request_does_not_cancel_the_handler() -> None: + """Closing the request's SSE connection while the handler is running leaves the handler running. + + The handler signals when it has started and when it has finished; the test drops the + connection in between and then releases the handler. If the disconnect cancelled the handler, + `finished` would never be set and the test would time out. 
+ """ + started = anyio.Event() + release = anyio.Event() + finished = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Signal start, wait for the test, signal completion.""" + started.set() + await release.wait() + await ctx.info("released") + finished.set() + return "held" + + async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "hold", {}), headers=base_headers(session_id=session_id) + ) as response: + await _read_events(response, 1) + await started.wait() + assert not finished.is_set() + release.set() + await finished.wait() + + +# This test intentionally carries every resumability requirement: the close-then-resume +# scenario is indivisible, so splitting it would mean six near-identical bodies. +@requirement("hosting:resume:close-stream") +@requirement("transport:streamable-http:resumability") +@requirement("client-transport:http:reconnect-post-priming") +@requirement("client-transport:http:reconnect-retry-value") +@requirement("client-transport:http:resume-stream-api") +@requirement("flow:resume:tool-call-resumption-token") +async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() -> None: + """A server-closed request stream is reconnected by the client and the call completes. + + The handler emits one notification, closes its own SSE stream, then (once released) emits + another and returns. The client observed the priming event (so it has a Last-Event-ID and a + retry hint of 0ms), sees the stream end, reconnects via GET with Last-Event-ID, and receives + the post-close notification and the result over the replay stream. The shared events make the + test deterministic: the handler only proceeds once the test knows the first notification has + arrived (and so the client's reconnection has begun). 
+ """ + received: list[object] = [] + before_seen = anyio.Event() + gate = anyio.Event() + done = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def interrupt(ctx: Context) -> str: + """Emit, close this call's SSE stream, then emit again after the test releases the gate.""" + await ctx.info("before close") + await ctx.close_sse_stream() + await gate.wait() + await ctx.info("after close") + done.set() + return "resumed" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + if params.data == "before close": + before_seen.set() + + result: list[CallToolResult] = [] + async with connect_over_streamable_http( + mcp, event_store=SequencedEventStore(), retry_interval=0, logging_callback=collect + ) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call() -> None: + result.append(await client.call_tool("interrupt", {})) + + tg.start_soon(call) + await before_seen.wait() + gate.set() + await done.wait() + + assert result == snapshot( + [CallToolResult(content=[TextContent(text="resumed")], structured_content={"result": "resumed"})] + ) + assert received == snapshot(["before close", "after close"]) diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py new file mode 100644 index 0000000000..561fbf251a --- /dev/null +++ b/tests/interaction/transports/test_hosting_session.py @@ -0,0 +1,203 @@ +"""Streamable HTTP session lifecycle: creation, routing, termination, and stateless mode. + +A test here speaks raw HTTP only when its assertion is the wire contract -- which header is +issued, which status code answers which condition -- that the SDK `Client` cannot observe. +Everything else is `Client`-driven against the same mounted session manager. Transport-agnostic +behaviour is covered by the `connect`-fixture matrix. 
+""" + +import re + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.types import JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool +from tests.interaction._connect import ( + base_headers, + client_via_http, + initialize_body, + initialize_via_http, + mounted_app, + post_jsonrpc, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A minimal low-level server with one tool, so subsequent-request routing can be observed.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", input_schema={"type": "object"})]) + + return Server("hosted", on_list_tools=list_tools) + + +@requirement("hosting:session:create") +@requirement("hosting:session:id-charset") +async def test_initialize_issues_a_visible_ascii_session_id() -> None: + """An initialize POST without a session ID creates a session and returns a visible-ASCII Mcp-Session-Id.""" + async with mounted_app(_server()) as (http, _): + response, messages = await post_jsonrpc(http, initialize_body()) + + assert response.status_code == 200 + session_id = response.headers.get("mcp-session-id") + assert session_id is not None + # The spec requires the session ID to consist only of visible ASCII (0x21-0x7E). 
@requirement("hosting:session:unknown-id")
async def test_requests_with_an_unknown_session_id_return_404() -> None:
    """The manager answers 404 to POST, GET and DELETE alike when the Mcp-Session-Id is unknown."""
    bogus_session = "not-a-session"
    list_request = {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}

    async with mounted_app(_server()) as (http, _):
        post_response = await http.post("/mcp", json=list_request, headers=base_headers(session_id=bogus_session))
        get_response = await http.get("/mcp", headers=base_headers(session_id=bogus_session))
        delete_response = await http.delete("/mcp", headers=base_headers(session_id=bogus_session))

    assert (post_response.status_code, post_response.json()) == snapshot(
        (404, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Session not found"}})
    )
    assert (get_response.status_code, delete_response.status_code) == (404, 404)
@requirement("hosting:session:delete")
@requirement("hosting:session:post-termination-404")
async def test_delete_terminates_the_session_and_subsequent_requests_return_404() -> None:
    """A DELETE on a live Mcp-Session-Id ends the session, and later requests on that ID get 404."""
    async with mounted_app(_server()) as (http, manager):
        session_id = await initialize_via_http(http)

        termination = await http.delete("/mcp", headers=base_headers(session_id=session_id))
        assert termination.status_code == 200

        # The terminated transport stays registered with the manager, so the follow-up request is
        # answered by the transport's own _terminated check, not the manager's unknown-session path.
        assert session_id in manager._server_instances
        list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}
        follow_up = await http.post("/mcp", json=list_request, headers=base_headers(session_id=session_id))

    assert (follow_up.status_code, follow_up.json()) == snapshot(
        (
            404,
            {
                "jsonrpc": "2.0",
                "id": None,
                "error": {"code": -32600, "message": "Not Found: Session has been terminated"},
            },
        )
    )
@requirement("hosting:session:reinitialize")
async def test_second_initialize_on_an_existing_session_is_accepted() -> None:
    """An initialize POST that reuses an existing session ID is answered instead of refused.

    This pins a divergence recorded on the requirement: the entry expects a rejection, whereas the
    SDK hands the second initialize to the running server, which treats it as a new handshake.
    """
    async with mounted_app(_server()) as (http, manager):
        session_id = await initialize_via_http(http)
        second_handshake = initialize_body(request_id=2)
        response, messages = await post_jsonrpc(http, second_handshake, session_id=session_id)
        # Still one session: the repeat handshake did not create another.
        assert len(manager._server_instances) == 1

    assert response.status_code == snapshot(200)
    reply = messages[0]
    assert isinstance(reply, JSONRPCResponse)
    assert reply.id == 2
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_server(), stateless_http=True, on_request=record) as (http, manager): + async with client_via_http(http) as client: + result = await client.list_tools() + assert manager._server_instances == {} + + assert result.tools[0].name == "noop" + assert all("mcp-session-id" not in request.headers for request in requests) + assert "DELETE" not in {request.method for request in requests} + + +@requirement("hosting:stateless:concurrent-clients") +async def test_stateless_mode_serves_concurrent_clients_independently() -> None: + """Two clients connected concurrently to the same stateless app each complete a round trip.""" + results: dict[str, ListToolsResult] = {} + + async with mounted_app(_server(), stateless_http=True) as (http, _): + + async def list_via(label: str) -> None: + async with client_via_http(http) as client: + results[label] = await client.list_tools() + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(list_via, "a") + tg.start_soon(list_via, "b") + + assert results["a"].tools[0].name == "noop" + assert results["b"].tools[0].name == "noop" diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py new file mode 100644 index 0000000000..4facadec73 --- /dev/null +++ b/tests/interaction/transports/test_sse.py @@ -0,0 +1,96 @@ +"""Behaviour specific to the legacy HTTP+SSE transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file pins only what is observable on the SSE wiring +itself: the GET-then-POST connection lifecycle, the endpoint event, and how the message endpoint +rejects requests it cannot route to a session. 
Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge. +""" + +import gc +import warnings +from uuid import UUID, uuid4 + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.sse import sse_client +from mcp.server import Server +from mcp.types import EmptyResult +from tests.interaction._connect import BASE_URL, build_sse_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +@requirement("transport:sse") +@requirement("transport:sse:endpoint-event") +async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None: + """Connecting opens a GET stream whose first event names the POST endpoint and a fresh + session id; messages POSTed there are answered on that stream, and disconnecting releases the + server's session entry.""" + app, sse = build_sse_app(Server("legacy")) + captured_session_id: list[str] = [] + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client( + f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append + ) + with anyio.fail_after(5): + async with Client(transport) as client: + assert len(captured_session_id) == 1 + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers + assert await client.send_ping() == snapshot(EmptyResult()) + + assert sse._read_stream_writers == {} + # See connect_over_sse: collect the one stream sse_starlette never closes on disconnect. 
@requirement("transport:sse:post:session-routing")
async def test_post_with_a_malformed_session_id_is_rejected() -> None:
    """The message endpoint answers 400 when the session_id query parameter is not a UUID."""
    app = build_sse_app(Server("legacy"))[0]
    bridge = StreamingASGITransport(app)
    ping = {"jsonrpc": "2.0", "method": "ping", "id": 1}
    async with httpx.AsyncClient(transport=bridge, base_url=BASE_URL) as http:
        response = await http.post("/messages/", params={"session_id": "not-a-uuid"}, json=ping)
    assert (response.status_code, response.text) == snapshot((400, "Invalid session ID"))
b/tests/interaction/transports/test_stdio.py @@ -0,0 +1,140 @@ +"""The stdio transport: one subprocess end-to-end test and one in-process framing test. + +Everything else in the suite runs in a single process; the subprocess test exists to prove the same +client↔server round trip works over the stdio transport's real boundary (a child process whose +stdin/stdout carry one newline-delimited JSON-RPC message per line). The server lives in +`_stdio_server.py` and is launched via `python -m` so subprocess coverage measurement applies. + +The framing test drives `stdio_server` in-process by passing it injected text streams instead of the +real stdin/stdout, so the raw lines the transport writes can be asserted directly without a process +boundary. + +stdio is deliberately not a leg of the `connect`-fixture matrix: spawning a subprocess per test +would be slow, and the matrix already proves transport-agnosticism over three in-process +transports. Process-lifecycle edge cases (escalation to terminate/kill, parse errors) are covered by +`tests/client/test_stdio.py` and stay deferred here. 
+""" + +import io +import json +import os +import sys +import tempfile +from pathlib import Path + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, +) +from mcp.types.jsonrpc import jsonrpc_message_adapter +from tests.interaction._connect import initialize_body +from tests.interaction._requirements import requirement +from tests.interaction.transports import _stdio_server + +pytestmark = pytest.mark.anyio + +_REPO_ROOT = Path(__file__).parents[3] + + +@requirement("transport:stdio") +@requirement("transport:stdio:clean-shutdown") +@requirement("transport:stdio:stderr-passthrough") +async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: + """A Client connected over stdio initializes, calls a tool with arguments, receives the + server's log notification before the call returns, and the server exits when the transport + closes its stdin.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + with tempfile.TemporaryFile(mode="w+") as errlog: + transport = stdio_client( + StdioServerParameters( + command=sys.executable, + args=["-m", _stdio_server.__name__], + cwd=str(_REPO_ROOT), + # stdio_client deliberately filters the inherited environment to a safe minimum, + # which drops the variables coverage.py's subprocess support uses; pass them through + # so the server module is measured. Empty when not running under coverage. 
@requirement("transport:stdio:stream-purity")
@requirement("transport:stdio:no-embedded-newlines")
async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None:
    """`stdio_server` emits nothing but valid JSON-RPC messages, each on a line of its own.

    `stdio_server` accepts `stdin` and `stdout` arguments, so in-process text streams stand in for
    the real process handles. A JSON-RPC line fed to stdin must arrive parsed on the read stream,
    and each message sent on the write stream must come out as exactly one newline-terminated
    line with any newline in its payload JSON-escaped. This covers the transport's own framing
    only; it does not protect `sys.stdout` from handler code that prints to it directly (see the
    divergence on `transport:stdio:stream-purity`).
    """
    captured = io.StringIO()
    sent_line = f"{json.dumps(initialize_body(request_id=1))}\n"
    fake_stdin = anyio.wrap_file(io.StringIO(sent_line))
    fake_stdout = anyio.wrap_file(captured)

    with anyio.fail_after(5):
        async with stdio_server(stdin=fake_stdin, stdout=fake_stdout) as (read_stream, write_stream):
            async with read_stream, write_stream:
                received = await read_stream.receive()
                assert isinstance(received, SessionMessage)
                assert isinstance(received.message, JSONRPCRequest)
                assert received.message.method == "initialize"

                response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"})
                notification = JSONRPCNotification(
                    jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"}
                )
                for outgoing in (response, notification):
                    await write_stream.send(SessionMessage(outgoing))

    output = captured.getvalue()
    assert output.endswith("\n")
    lines = output.removesuffix("\n").split("\n")
    assert len(lines) == 2
    first_line, second_line = lines
    messages = [jsonrpc_message_adapter.validate_json(line) for line in lines]
    message_kinds = [type(message).__name__ for message in messages]
    assert message_kinds == snapshot(["JSONRPCResponse", "JSONRPCNotification"])
    # A newline inside a payload travels as the two-character JSON escape, never as a literal
    # newline that would split one message across two lines.
    assert r"line\nbreak" in first_line
    assert r"two\nlines" in second_line
+
+Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of
+the suite over this transport as well; this file only pins what cannot be observed in memory: the
+server's stateless and JSON-response modes, the standalone GET stream, and the full-duplex
+server-initiated exchange on a still-open call. Every test drives the server's real Starlette app
+through the suite's streaming ASGI bridge — no sockets, threads, or subprocesses.
+"""
+
+import anyio
+import pytest
+from inline_snapshot import snapshot
+from pydantic import BaseModel
+
+from mcp.client import ClientRequestContext
+from mcp.server.elicitation import AcceptedElicitation
+from mcp.server.mcpserver import Context, MCPServer
+from mcp.types import (
+    CallToolResult,
+    ElicitRequestParams,
+    ElicitResult,
+    LoggingMessageNotification,
+    LoggingMessageNotificationParams,
+    ResourceUpdatedNotification,
+    ResourceUpdatedNotificationParams,
+    TextContent,
+)
+from tests.interaction._connect import connect_over_streamable_http
+from tests.interaction._helpers import IncomingMessage
+from tests.interaction._requirements import requirement
+
+pytestmark = pytest.mark.anyio
+
+
+def _smoke_server() -> MCPServer:
+    """A server exercising each message shape the transport-specific tests need."""
+    mcp = MCPServer("smoke", instructions="Talk to the smoke server.")
+
+    @mcp.tool()
+    def echo(text: str) -> str:
+        """Echo the text back."""
+        return text
+
+    class Confirmation(BaseModel):  # schema passed to ctx.elicit() in `ask`
+        confirmed: bool
+
+    @mcp.tool()
+    async def ask(ctx: Context) -> str:
+        """Elicit a confirmation from the client and report the outcome."""
+        answer = await ctx.elicit("Proceed?", Confirmation)
+        # In stateless mode the elicit raises before this point: there is no session to call back through.
+        assert isinstance(answer, AcceptedElicitation)
+        return f"confirmed={answer.data.confirmed}"
+
+    @mcp.tool()
+    async def announce(ctx: Context) -> str:
+        """Send one notification related to this request and one that is not."""
+        await ctx.info("about to announce")  # the notification related to this request
+        await ctx.session.send_resource_updated("file:///watched.txt")  # the one that is not
+        return "announced"
+
+    return mcp
+
+
+@requirement("transport:streamable-http:json-response")
+@requirement("hosting:http:json-response-mode")
+@requirement("client-transport:http:json-response-parsed")
+async def test_tool_call_over_streamable_http_with_json_responses() -> None:
+    """The round trip works when the server answers with a single JSON body instead of an SSE stream."""
+    async with connect_over_streamable_http(_smoke_server(), json_response=True) as client:
+        assert client.initialize_result.server_info.name == "smoke"
+        result = await client.call_tool("echo", {"text": "as json"})
+
+    assert result == snapshot(
+        CallToolResult(content=[TextContent(text="as json")], structured_content={"result": "as json"})
+    )
+
+
+@requirement("transport:streamable-http:stateless")
+async def test_tool_calls_over_stateless_streamable_http() -> None:
+    """Consecutive requests each succeed against a stateless server with no session to share."""
+    async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client:
+        first = await client.call_tool("echo", {"text": "first"})
+        second = await client.call_tool("echo", {"text": "second"})
+
+    assert first == snapshot(
+        CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"})
+    )
+    assert second == snapshot(
+        CallToolResult(content=[TextContent(text="second")], structured_content={"result": "second"})
+    )
+
+
+@requirement("transport:streamable-http:stateless-restrictions")
+async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None:
+    """A handler that tries to call back to the client in stateless mode fails: there is no session."""
+    async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client:
+        result = await client.call_tool("ask", {})
+
+    assert result.is_error is True
+    assert isinstance(result.content[0], TextContent)
+    # The exact message is the StatelessModeNotSupported exception text wrapped by the tool-error
+    # path; pin the stable prefix rather than the full exception prose.
+    assert result.content[0].text.startswith("Error executing tool ask:")
+
+
+@requirement("transport:streamable-http:notifications")
+@requirement("transport:streamable-http:unrelated-messages")
+@requirement("hosting:http:standalone-sse")
+async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> None:
+    """A server message with no related request reaches the client through the standalone GET stream.
+
+    The log notification is related to the tool call and travels on that call's own SSE stream;
+    the resource-updated notification is not related to any request, so the only way it can reach
+    the client is the standalone stream the client opens after initialization. Delivery order
+    across the two streams is not guaranteed, so the unrelated message is awaited rather than
+    assumed to beat the tool result.
+    """
+    received: list[IncomingMessage] = []
+    resource_update_seen = anyio.Event()  # set by collect() when a ResourceUpdatedNotification arrives
+
+    async def collect(message: IncomingMessage) -> None:  # passed as message_handler below
+        received.append(message)
+        if isinstance(message, ResourceUpdatedNotification):
+            resource_update_seen.set()
+
+    async with connect_over_streamable_http(_smoke_server(), message_handler=collect) as client:
+        result = await client.call_tool("announce", {})
+        with anyio.fail_after(5):
+            await resource_update_seen.wait()
+
+    assert result == snapshot(
+        CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"})
+    )
+    # The related log notification rides the call's stream; the unrelated resource-updated
+    # notification rides the standalone stream. Both arrive, nothing else does.
+    assert [message for message in received if isinstance(message, LoggingMessageNotification)] == snapshot(
+        [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="about to announce"))]
+    )
+    assert [message for message in received if isinstance(message, ResourceUpdatedNotification)] == snapshot(
+        [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))]
+    )
+    assert len(received) == 2  # i.e. only the two notifications matched above
+
+
+@requirement("transport:streamable-http:stateful")
+@requirement("transport:streamable-http:server-to-client")
+async def test_server_initiated_elicitation_round_trips_during_a_tool_call() -> None:
+    """An elicitation issued mid-call reaches the client and its answer reaches the handler over stateful HTTP.
+
+    The elicitation request travels on the still-open SSE response of the tool call that triggered
+    it, and the client's answer arrives as a separate POST -- the full-duplex exchange the
+    streamable HTTP transport exists to provide.
+    """
+    asked: list[ElicitRequestParams] = []
+
+    async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult:
+        asked.append(params)  # recorded so the final assertion can check what the server asked
+        return ElicitResult(action="accept", content={"confirmed": True})
+
+    async with connect_over_streamable_http(_smoke_server(), elicitation_callback=answer) as client:
+        # Bounded because a harness regression here historically meant deadlock, not failure.
+        with anyio.fail_after(5):
+            result = await client.call_tool("ask", {})
+
+    assert result == snapshot(
+        CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"})
+    )
+    assert [params.message for params in asked] == snapshot(["Proceed?"])