From 41ab75dc8058dbf9b25ed62b6f08c77990de837f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 15:49:55 +0000 Subject: [PATCH] Bind transport sessions to the authenticated principal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both HTTP transports now record the principal that created each session — the OAuth client together with the issuer and subject when the token verifier supplies them — and serve subsequent requests for that session only when they present the same principal. Requests presenting a different principal receive the same 404 response as for an unknown session ID, and SSE session entries are removed when the connection ends. Servers without authentication, and authentication backends other than the built-in BearerAuthBackend, are unaffected: no principal is recorded and the comparison always passes. The new in-process SSE tests bring connect_sse, handle_post_message, and TransportSecurityMiddleware under tracked coverage, so the corresponding no-cover pragmas are removed. --- src/mcp/server/auth/middleware/bearer_auth.py | 26 +- src/mcp/server/sse.py | 66 ++-- src/mcp/server/streamable_http_manager.py | 40 ++- src/mcp/server/transport_security.py | 14 +- tests/server/test_sse_security.py | 288 +++++++++++++++++- tests/server/test_streamable_http_manager.py | 167 +++++++++- tests/server/test_transport_security.py | 88 ++++++ 7 files changed, 647 insertions(+), 42 deletions(-) create mode 100644 tests/server/test_transport_security.py diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 2eafdc793e..ba66e94226 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,6 +1,6 @@ import json import time -from typing import Any +from typing import Any, TypedDict from pydantic import AnyHttpUrl from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken): self.scopes = auth_info.scopes +class AuthorizationContext(TypedDict): + client_id: str + issuer: str | None + subject: str | None + + +def authorization_context(user: AuthenticatedUser) -> AuthorizationContext: + """Identify the principal `user` represents, for transports to compare + against the principal that created a session. Components the token + verifier does not supply are `None`, so the comparison degrades to the + remaining components. + + See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for + a verifier that populates `subject` and `claims` from an introspection + response.""" + token = user.access_token + issuer = (token.claims or {}).get("iss") + return AuthorizationContext( + client_id=token.client_id, + issuer=str(issuer) if issuer is not None else None, + subject=token.subject, + ) + + class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index be8e979c9d..05e948332b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -50,6 +50,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send from mcp import types +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.transport_security import ( TransportSecurityMiddleware, TransportSecuritySettings, @@ -73,6 +74,9 @@ class SseServerTransport: _endpoint: str _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] + # Identity of the credential that created each session; requests for a + # session must present the same credential. + _session_owners: dict[UUID, AuthorizationContext] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -112,19 +116,20 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | self._endpoint = endpoint self._read_stream_writers = {} + self._session_owners = {} self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] != "http": # pragma: no cover + if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") # Validate request headers for DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) - if error_response: # pragma: no cover + if error_response: await error_response(scope, receive, send) raise ValueError("Request validation failed") @@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = create_context_streams[SessionMessage](0) session_id = uuid4() + user = scope.get("user") + if isinstance(user, AuthenticatedUser): + self._session_owners[session_id] = authorization_context(user) self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") @@ -169,27 +177,30 @@ async def sse_writer(): } ) - async with anyio.create_task_group() as tg: - - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. - """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await sse_stream_reader.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() - self._read_stream_writers.pop(session_id, None) - logging.debug(f"Client session disconnected {session_id}") + try: + async with anyio.create_task_group() as tg: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + await sse_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + finally: + self._read_stream_writers.pop(session_id, None) + self._session_owners.pop(session_id, None) async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") @@ -197,7 +208,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) - if error_response: # pragma: no cover + if error_response: return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") @@ -220,13 +231,22 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + if requestor != self._session_owners.get(session_id): + # A session can only be used with the credential that created it. + # Respond exactly as if the session did not exist. + logger.warning("Rejecting message for session %s: credential does not match", session_id) + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + body = await request.body() logger.debug(f"Received JSON: {body}") try: message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") - except ValidationError as err: # pragma: no cover + except ValidationError as err: logger.exception("Failed to parse message") response = Response("Could not parse message", status_code=400) await response(scope, receive, send) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 39d434505c..81350a8f24 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -5,7 +5,6 @@ import contextlib import logging from collections.abc import AsyncIterator -from http import HTTPStatus from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -15,6 +14,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, EventStore, @@ -89,6 +89,9 @@ def __init__( # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + # Identity of the credential that created each session; requests for a + # session must present the same credential. + self._session_owners: dict[str, AuthorizationContext] = {} # The task group will be set during lifespan self._task_group = None @@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._session_owners.clear() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. @@ -192,9 +196,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + # Existing session case if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] + if requestor != self._session_owners.get(request_mcp_session_id): + # A session can only be used with the credential that created + # it. Respond exactly as if the session did not exist. + logger.warning( + "Rejecting request for session %s: credential does not match the one that created the session", + request_mcp_session_id[:64], + ) + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") + ) + response = Response( + body.model_dump_json(by_alias=True, exclude_unset=True), + status_code=404, + media_type="application/json", + ) + await response(scope, receive, send) + return logger.debug("Session already exists, handling request directly") # Push back idle deadline on activity if transport.idle_scope is not None and self.session_idle_timeout is not None: @@ -216,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S ) assert http_transport.mcp_session_id is not None + if requestor is not None: + self._session_owners[http_transport.mcp_session_id] = requestor self._server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") @@ -246,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE assert http_transport.mcp_session_id is not None logger.info(f"Session {http_transport.mcp_session_id} idle timeout") self._server_instances.pop(http_transport.mcp_session_id, None) + self._session_owners.pop(http_transport.mcp_session_id, None) await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") @@ -260,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE f"{http_transport.mcp_session_id} from active instances." ) del self._server_instances[http_transport.mcp_session_id] + self._session_owners.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None @@ -273,15 +301,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # TODO: Align error code once spec clarifies # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=None, - error=ErrorData(code=INVALID_REQUEST, message="Session not found"), + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_unset=True), - status_code=HTTPStatus.NOT_FOUND, - media_type="application/json", + body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json" ) await response(scope, receive, send) diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 707d4b61dd..d9e9f965b3 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -42,17 +42,17 @@ def __init__(self, settings: TransportSecuritySettings | None = None): def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not host: # pragma: no cover + if not host: logger.warning("Missing Host header in request") return False # Check exact match first - if host in self.settings.allowed_hosts: # pragma: no cover + if host in self.settings.allowed_hosts: return True # Check wildcard port patterns for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): # pragma: no branch + if allowed.endswith(":*"): # Extract base host from pattern base_host = allowed[:-2] # Check if the actual host starts with base host and has a port @@ -65,16 +65,16 @@ def _validate_host(self, host: str | None) -> bool: def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests - if not origin: # pragma: no cover + if not origin: return True # Check exact match first - if origin in self.settings.allowed_origins: # pragma: no cover + if origin in self.settings.allowed_origins: return True # Check wildcard port patterns for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): # pragma: no branch + if allowed.endswith(":*"): # Extract base origin from pattern base_origin = allowed[:-2] # Check if the actual origin starts with base origin and has a port @@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: # pragma: no branch + if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..e95dc51b31 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,27 +1,44 @@ -"""Tests for SSE server DNS rebinding protection.""" +"""Tests for SSE server request validation.""" import logging import multiprocessing +import re import socket +import anyio import httpx import pytest +import sse_starlette.sse import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Message, Receive, Scope, Send from mcp.server import Server +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool +from mcp.shared._stream_protocols import WriteStream +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool from tests.test_helpers import wait_for_server logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +@pytest.fixture(autouse=True) +def reset_sse_starlette_exit_event() -> None: + """sse-starlette<2 caches a module-level anyio.Event on AppStatus; reset it + between tests so it is not bound to a previous test's event loop.""" + app_status = getattr(sse_starlette.sse, "AppStatus", None) + if app_status is not None and hasattr(app_status, "should_exit_event"): # pragma: lax no cover + app_status.should_exit_event = None + + @pytest.fixture def server_port() -> int: with socket.socket() as s: @@ -291,3 +308,270 @@ async def test_sse_security_post_valid_content_type(server_port: int): finally: process.terminate() process.join() + + +def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _sse_scope( + method: str, path: str, user: AuthenticatedUser | None, *, query_string: bytes = b"", body: bytes = b"" +) -> tuple[Scope, Receive, Send, list[Message]]: + """Build an ASGI scope/receive/send triple for a request to the SSE transport.""" + scope: Scope = { + "type": "http", + "method": method, + "path": path, + "root_path": "", + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + } + if user is not None: + scope["user"] = user + sent: list[Message] = [] + + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + async def send(message: Message) -> None: + sent.append(message) + + return scope, receive, send, sent + + +def _response_status(sent: list[Message]) -> int: + response_start = next(msg for msg in sent if msg["type"] == "http.response.start") + return response_start["status"] + + +async def _post_message(transport: SseServerTransport, session_id: str, user: AuthenticatedUser | None) -> int: + """POST a message to an SSE session as `user` and return the response status.""" + body = b'{"jsonrpc": "2.0", "id": 1, "method": "ping", "params": null}' + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", user, query_string=f"session_id={session_id}".encode(), body=body + ) + await transport.handle_post_message(scope, receive, send) + return _response_status(sent) + + +_Principal = tuple[str] | tuple[str, str] | tuple[str, str, str] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("creator", "sender", "expected"), + [ + pytest.param(("client-a",), ("client-b",), 404, id="different-client"), + pytest.param(("client-a",), None, 404, id="unauthenticated-sender"), + pytest.param(("client-a", "alice"), ("client-a", "bob"), 404, id="same-client-different-subject"), + pytest.param(("client-a", "alice"), ("client-a",), 404, id="same-client-no-subject"), + pytest.param( + ("client-a", "alice", "https://i1"), ("client-a", "alice", "https://i2"), 404, id="different-issuer" + ), + pytest.param(None, ("client-a",), 404, id="unauthenticated-creator"), + pytest.param(("client-a",), ("client-a",), 202, id="same-client"), + pytest.param(("client-a", "alice"), ("client-a", "alice"), 202, id="same-client-and-subject"), + pytest.param(None, None, 202, id="both-unauthenticated"), + ], +) +async def test_sse_post_requires_the_credential_that_created_the_session( + creator: _Principal | None, + sender: _Principal | None, + expected: int, +): + """The session endpoint URL issued to one authenticated principal must not + accept messages from a request authenticated as a different one.""" + transport = SseServerTransport("/messages/") + session_id_received = anyio.Event() + session_ids: list[str] = [] + client_disconnected = anyio.Event() + + async def get_send(message: Message) -> None: + # The first body chunk is the SSE event announcing the session URI to POST messages to. + if message["type"] == "http.response.body" and not session_ids: + match = re.search(rb"session_id=([0-9a-f]{32})", message.get("body", b"")) + assert match is not None, f"expected the endpoint event first, got {message!r}" + session_ids.append(match.group(1).decode()) + session_id_received.set() + + async def get_receive() -> Message: + # The SSE client stays connected until the test signals otherwise. + await client_disconnected.wait() + return {"type": "http.disconnect"} + + creator_user = _authenticated_user(*creator) if creator is not None else None + sender_user = _authenticated_user(*sender) if sender is not None else None + + async def hold_sse_connection() -> None: + """Establish the SSE session as `creator` and keep it open, as a server would.""" + scope, _, _, _ = _sse_scope("GET", "/sse", creator_user) + with anyio.fail_after(5): + async with transport.connect_sse(scope, get_receive, get_send) as (read_stream, write_stream): + async with read_stream, write_stream: # pragma: no branch + async for _ in read_stream: + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(hold_sse_connection) + with anyio.fail_after(5): + await session_id_received.wait() + + assert await _post_message(transport, session_ids[0], sender_user) == expected + + client_disconnected.set() + + # Once the connection is gone the session is no longer routable. + assert await _post_message(transport, session_ids[0], creator_user) == 404 + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_non_http_scope(): + """connect_sse refuses ASGI scopes that are not HTTP requests.""" + transport = SseServerTransport("/messages/") + with pytest.raises(ValueError): + async with transport.connect_sse({"type": "websocket"}, _no_receive, _no_send): + raise NotImplementedError + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_disallowed_host(): + """connect_sse rejects requests whose Host header fails the configured security check.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("GET", "/sse", None) + scope["headers"] = [(b"host", b"disallowed.example.com")] + + with pytest.raises(ValueError): + async with transport.connect_sse(scope, receive, send): + raise NotImplementedError + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_post_without_a_session_id_returns_400(): + """POSTs to the messages endpoint must include a session_id query parameter.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_malformed_session_id_returns_400(): + """A session_id that is not 32 hex characters is rejected before any session lookup.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None, query_string=b"session_id=not-hex") + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_disallowed_host_is_rejected_before_session_lookup(): + """The transport security check on POST runs before any session-ID handling.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + scope["headers"] = [(b"host", b"disallowed.example.com"), (b"content-type", b"application/json")] + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_round_trip_delivers_posted_messages_and_streams_responses(): + """A POSTed JSON-RPC message reaches the server's read stream, and a message + written to the server's write stream is sent to the client as an SSE event.""" + transport = SseServerTransport("/messages/") + session = _SseSession(transport) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(session.hold) + await session.ready.wait() + + # POST a parse-failing body: client gets 400, server's read stream receives the error. + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", None, query_string=f"session_id={session.session_id}".encode(), body=b"not json" + ) + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + assert isinstance(await session.next_read_item(), Exception) + + # POST a valid message: client gets 202, server's read stream receives it. + assert await _post_message(transport, session.session_id, None) == 202 + received = await session.next_read_item() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "ping" + + # Server writes a response: it appears as an SSE `message` event on the GET stream. + outgoing = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + await session.write_stream.send(SessionMessage(outgoing)) + chunk = await session.next_body_chunk() + assert b"event: message" in chunk + assert outgoing.model_dump_json(by_alias=True, exclude_unset=True).encode() in chunk + + session.disconnect() + + +class _SseSession: + """Drive an in-process SSE GET connection and surface what the server reads and the client receives. + + `hold` runs the connection in a background task and consumes the server-side read stream + into a buffer so that `handle_post_message` (which writes to that stream with a zero-capacity + channel) never blocks the test body. + """ + + def __init__(self, transport: SseServerTransport) -> None: + self.transport = transport + self.ready = anyio.Event() + self._disconnected = anyio.Event() + self._body_send, self._body_recv = anyio.create_memory_object_stream[bytes](16) + self._read_send, self._read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](16) + self.session_id = "" + self.write_stream: WriteStream[SessionMessage] + + async def hold(self) -> None: + scope, _, _, _ = _sse_scope("GET", "/sse", None) + async with self.transport.connect_sse(scope, self._receive, self._send) as (read, write): + self.write_stream = write + async with read, write, self._body_send, self._body_recv, self._read_send, self._read_recv: + async for item in read: + await self._read_send.send(item) + + def disconnect(self) -> None: + self._disconnected.set() + + async def next_read_item(self) -> SessionMessage | Exception: + return await self._read_recv.receive() + + async def next_body_chunk(self) -> bytes: + return await self._body_recv.receive() + + async def _receive(self) -> Message: + await self._disconnected.wait() + return {"type": "http.disconnect"} + + async def _send(self, message: Message) -> None: + if message["type"] != "http.response.body": + return + body: bytes = message.get("body", b"") + if not self.session_id: + match = re.search(rb"session_id=([0-9a-f]{32})", body) + assert match is not None, f"expected the endpoint event first, got {message!r}" + self.session_id = match.group(1).decode() + self.ready.set() + else: + await self._body_send.send(body) + + +async def _no_receive() -> Message: + raise NotImplementedError + + +async def _no_send(message: Message) -> None: + raise NotImplementedError diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 47cfbf14a4..ba75547964 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -8,11 +8,13 @@ import anyio import httpx import pytest -from starlette.types import Message +from starlette.types import Message, Scope from mcp import Client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext, streamable_http_manager +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @@ -413,3 +415,166 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +def _user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _request_scope( + *, session_id: str | None = None, user: AuthenticatedUser | None = None, method: str = "POST" +) -> Scope: + """Build an ASGI scope for a request to the MCP endpoint.""" + headers = [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ] + if session_id is not None: + headers.append((b"mcp-session-id", session_id.encode())) + scope: Scope = { + "type": "http", + "method": method, + "path": "/mcp", + "headers": headers, + } + if user is not None: + scope["user"] = user + return scope + + +async def _open_session(manager: StreamableHTTPSessionManager, user: AuthenticatedUser | None) -> str: + """Create a new session as `user` and return its session ID.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(_request_scope(user=user), mock_receive, mock_send) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + headers = dict(response_start.get("headers", [])) + return headers[MCP_SESSION_ID_HEADER.encode()].decode() + + +async def _request_session( + manager: StreamableHTTPSessionManager, session_id: str, user: AuthenticatedUser | None, method: str = "POST" +) -> int: + """Send a request for an existing session as `user` and return the response status.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request( + _request_scope(session_id=session_id, user=user, method=method), mock_receive, mock_send + ) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + return response_start["status"] + + +@pytest.fixture +async def manager_with_live_session(): + """A running manager around a real `Server`. Sessions remain registered until + `manager.run()` exits because `Server.run` blocks waiting for an initialize message.""" + manager = StreamableHTTPSessionManager(app=Server("test-session-credentials")) + async with manager.run(): + yield manager + + +@pytest.mark.anyio +async def test_session_accepts_requests_from_the_credential_that_created_it( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Requests presenting the same credential as the one that created the session are served.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + status = await _request_session(manager, session_id, _user("client-a")) + + # The request passes the manager's credential check and reaches the + # session's transport, instead of being answered with 404 by the manager. + assert status != 404 + + +@pytest.mark.anyio +@pytest.mark.parametrize("method", ["POST", "GET", "DELETE"]) +async def test_session_rejects_requests_from_a_different_credential( + manager_with_live_session: StreamableHTTPSessionManager, method: str +) -> None: + """A session created by one credential cannot be used with another credential, whatever the method.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, _user("client-b"), method) == 404 + # The session is still registered and still serves its creator. + assert await _request_session(manager, session_id, _user("client-a")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_from_a_different_subject_of_the_same_client( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Two end-users that share an OAuth client cannot use each other's sessions.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a", subject="alice")) + + assert await _request_session(manager, session_id, _user("client-a", subject="bob")) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject=None)) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_with_the_same_subject_from_a_different_issuer( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A subject is unique only per issuer, so a colliding subject from a different issuer is not the same principal.""" + manager = manager_with_live_session + creator = _user("client-a", subject="alice", issuer="https://issuer.one") + session_id = await _open_session(manager, creator) + + other_issuer = _user("client-a", subject="alice", issuer="https://issuer.two") + assert await _request_session(manager, session_id, other_issuer) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) == 404 + assert await _request_session(manager, session_id, creator) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_unauthenticated_requests_for_an_authenticated_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created with a credential cannot be used without one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, None) == 404 + + +@pytest.mark.anyio +async def test_session_rejects_authenticated_requests_for_an_anonymous_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created without a credential cannot be used with one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, _user("client-a")) == 404 + + +@pytest.mark.anyio +async def test_anonymous_session_accepts_anonymous_requests( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Servers without authentication keep working: no credential on either side.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, None) != 404 diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 0000000000..be28980b53 --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,88 @@ +"""Tests for the transport-security request validation middleware.""" + +import pytest +from starlette.requests import Request + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def _request(host: str | None, origin: str | None, content_type: str | None = "application/json") -> Request: + headers: list[tuple[bytes, bytes]] = [] + if content_type is not None: + headers.append((b"content-type", content_type.encode())) + if host is not None: + headers.append((b"host", host.encode())) + if origin is not None: + headers.append((b"origin", origin.encode())) + return Request({"type": "http", "method": "GET", "headers": headers}) + + +SETTINGS = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["good.example", "wild.example:*"], + allowed_origins=["http://good.example", "http://wild.example:*"], +) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("host", "origin", "expected"), + [ + pytest.param(None, None, 421, id="missing-host"), + pytest.param("evil.example", None, 421, id="host-no-match"), + pytest.param("evil.example:9000", None, 421, id="host-wildcard-base-mismatch"), + pytest.param("good.example", None, None, id="host-exact-no-origin"), + pytest.param("wild.example:9000", None, None, id="host-wildcard-match"), + pytest.param("good.example", "http://evil.example", 403, id="origin-no-match"), + pytest.param("good.example", "http://evil.example:9000", 403, id="origin-wildcard-base-mismatch"), + pytest.param("good.example", "http://good.example", None, id="origin-exact"), + pytest.param("good.example", "http://wild.example:9000", None, id="origin-wildcard-match"), + ], +) +async def test_validate_request_checks_host_then_origin( + host: str | None, origin: str | None, expected: int | None +) -> None: + """Host is checked first, then Origin; exact and wildcard-port allowlist entries are honoured.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request(host, origin)) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_skips_host_and_origin_when_protection_is_disabled() -> None: + """With DNS-rebinding protection off, any Host/Origin is accepted.""" + middleware = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +async def test_validate_request_defaults_to_protection_disabled() -> None: + """Constructing the middleware without settings leaves DNS-rebinding protection off.""" + middleware = TransportSecurityMiddleware() + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("content_type", "expected"), + [ + pytest.param("application/json", None, id="json"), + pytest.param("application/json; charset=utf-8", None, id="json-with-charset"), + pytest.param("APPLICATION/JSON", None, id="case-insensitive"), + pytest.param("text/plain", 400, id="wrong-type"), + pytest.param(None, 400, id="missing"), + ], +) +async def test_validate_request_checks_content_type_on_post(content_type: str | None, expected: int | None) -> None: + """POST requests must carry an application/json Content-Type, regardless of DNS-rebinding settings.""" + middleware = TransportSecurityMiddleware() + response = await middleware.validate_request(_request("any", None, content_type=content_type), is_post=True) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_ignores_content_type_on_get() -> None: + """Content-Type is only enforced for POST requests.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request("good.example", None, content_type="text/plain")) + assert response is None