diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 2eafdc793..ba66e9422 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,6 +1,6 @@ import json import time -from typing import Any +from typing import Any, TypedDict from pydantic import AnyHttpUrl from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken): self.scopes = auth_info.scopes +class AuthorizationContext(TypedDict): + client_id: str + issuer: str | None + subject: str | None + + +def authorization_context(user: AuthenticatedUser) -> AuthorizationContext: + """Identify the principal `user` represents, for transports to compare + against the principal that created a session. Components the token + verifier does not supply are `None`, so the comparison degrades to the + remaining components. + + See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for + a verifier that populates `subject` and `claims` from an introspection + response.""" + token = user.access_token + issuer = (token.claims or {}).get("iss") + return AuthorizationContext( + client_id=token.client_id, + issuer=str(issuer) if issuer is not None else None, + subject=token.subject, + ) + + class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index be8e979c9..05e948332 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -50,6 +50,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send from mcp import types +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.transport_security import ( TransportSecurityMiddleware, TransportSecuritySettings, @@ -73,6 +74,9 @@ class SseServerTransport: _endpoint: str _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] + # Identity of the credential that created each session; requests for a + # session must present the same credential. + _session_owners: dict[UUID, AuthorizationContext] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -112,19 +116,20 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | self._endpoint = endpoint self._read_stream_writers = {} + self._session_owners = {} self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] != "http": # pragma: no cover + if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") # Validate request headers for DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) - if error_response: # pragma: no cover + if error_response: await error_response(scope, receive, send) raise ValueError("Request validation failed") @@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = create_context_streams[SessionMessage](0) session_id = uuid4() + user = scope.get("user") + if isinstance(user, AuthenticatedUser): + self._session_owners[session_id] = authorization_context(user) self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") @@ -169,27 +177,30 @@ async def sse_writer(): } ) - async with anyio.create_task_group() as tg: - - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. - """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await sse_stream_reader.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() - self._read_stream_writers.pop(session_id, None) - logging.debug(f"Client session disconnected {session_id}") + try: + async with anyio.create_task_group() as tg: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + await sse_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + finally: + self._read_stream_writers.pop(session_id, None) + self._session_owners.pop(session_id, None) async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") @@ -197,7 +208,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) - if error_response: # pragma: no cover + if error_response: return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") @@ -220,13 +231,22 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + if requestor != self._session_owners.get(session_id): + # A session can only be used with the credential that created it. + # Respond exactly as if the session did not exist. + logger.warning("Rejecting message for session %s: credential does not match", session_id) + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + body = await request.body() logger.debug(f"Received JSON: {body}") try: message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") - except ValidationError as err: # pragma: no cover + except ValidationError as err: logger.exception("Failed to parse message") response = Response("Could not parse message", status_code=400) await response(scope, receive, send) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 39d434505..81350a8f2 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -5,7 +5,6 @@ import contextlib import logging from collections.abc import AsyncIterator -from http import HTTPStatus from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -15,6 +14,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, EventStore, @@ -89,6 +89,9 @@ def __init__( # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + # Identity of the credential that created each session; requests for a + # session must present the same credential. + self._session_owners: dict[str, AuthorizationContext] = {} # The task group will be set during lifespan self._task_group = None @@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._session_owners.clear() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. @@ -192,9 +196,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + # Existing session case if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] + if requestor != self._session_owners.get(request_mcp_session_id): + # A session can only be used with the credential that created + # it. Respond exactly as if the session did not exist. + logger.warning( + "Rejecting request for session %s: credential does not match the one that created the session", + request_mcp_session_id[:64], + ) + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") + ) + response = Response( + body.model_dump_json(by_alias=True, exclude_unset=True), + status_code=404, + media_type="application/json", + ) + await response(scope, receive, send) + return logger.debug("Session already exists, handling request directly") # Push back idle deadline on activity if transport.idle_scope is not None and self.session_idle_timeout is not None: @@ -216,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S ) assert http_transport.mcp_session_id is not None + if requestor is not None: + self._session_owners[http_transport.mcp_session_id] = requestor self._server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") @@ -246,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE assert http_transport.mcp_session_id is not None logger.info(f"Session {http_transport.mcp_session_id} idle timeout") self._server_instances.pop(http_transport.mcp_session_id, None) + self._session_owners.pop(http_transport.mcp_session_id, None) await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") @@ -260,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE f"{http_transport.mcp_session_id} from active instances." ) del self._server_instances[http_transport.mcp_session_id] + self._session_owners.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None @@ -273,15 +301,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # TODO: Align error code once spec clarifies # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=None, - error=ErrorData(code=INVALID_REQUEST, message="Session not found"), + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_unset=True), - status_code=HTTPStatus.NOT_FOUND, - media_type="application/json", + body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json" ) await response(scope, receive, send) diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 707d4b61d..d9e9f965b 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -42,17 +42,17 @@ def __init__(self, settings: TransportSecuritySettings | None = None): def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not host: # pragma: no cover + if not host: logger.warning("Missing Host header in request") return False # Check exact match first - if host in self.settings.allowed_hosts: # pragma: no cover + if host in self.settings.allowed_hosts: return True # Check wildcard port patterns for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): # pragma: no branch + if allowed.endswith(":*"): # Extract base host from pattern base_host = allowed[:-2] # Check if the actual host starts with base host and has a port @@ -65,16 +65,16 @@ def _validate_host(self, host: str | None) -> bool: def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests - if not origin: # pragma: no cover + if not origin: return True # Check exact match first - if origin in self.settings.allowed_origins: # pragma: no cover + if origin in self.settings.allowed_origins: return True # Check wildcard port patterns for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): # pragma: no branch + if allowed.endswith(":*"): # Extract base origin from pattern base_origin = allowed[:-2] # Check if the actual origin starts with base origin and has a port @@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: # pragma: no branch + if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a2..e95dc51b3 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,27 +1,44 @@ -"""Tests for SSE server DNS rebinding protection.""" +"""Tests for SSE server request validation.""" import logging import multiprocessing +import re import socket +import anyio import httpx import pytest +import sse_starlette.sse import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Message, Receive, Scope, Send from mcp.server import Server +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool +from mcp.shared._stream_protocols import WriteStream +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool from tests.test_helpers import wait_for_server logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +@pytest.fixture(autouse=True) +def reset_sse_starlette_exit_event() -> None: + """sse-starlette<2 caches a module-level anyio.Event on AppStatus; reset it + between tests so it is not bound to a previous test's event loop.""" + app_status = getattr(sse_starlette.sse, "AppStatus", None) + if app_status is not None and hasattr(app_status, "should_exit_event"): # pragma: lax no cover + app_status.should_exit_event = None + + @pytest.fixture def server_port() -> int: with socket.socket() as s: @@ -291,3 +308,270 @@ async def test_sse_security_post_valid_content_type(server_port: int): finally: process.terminate() process.join() + + +def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _sse_scope( + method: str, path: str, user: AuthenticatedUser | None, *, query_string: bytes = b"", body: bytes = b"" +) -> tuple[Scope, Receive, Send, list[Message]]: + """Build an ASGI scope/receive/send triple for a request to the SSE transport.""" + scope: Scope = { + "type": "http", + "method": method, + "path": path, + "root_path": "", + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + } + if user is not None: + scope["user"] = user + sent: list[Message] = [] + + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + async def send(message: Message) -> None: + sent.append(message) + + return scope, receive, send, sent + + +def _response_status(sent: list[Message]) -> int: + response_start = next(msg for msg in sent if msg["type"] == "http.response.start") + return response_start["status"] + + +async def _post_message(transport: SseServerTransport, session_id: str, user: AuthenticatedUser | None) -> int: + """POST a message to an SSE session as `user` and return the response status.""" + body = b'{"jsonrpc": "2.0", "id": 1, "method": "ping", "params": null}' + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", user, query_string=f"session_id={session_id}".encode(), body=body + ) + await transport.handle_post_message(scope, receive, send) + return _response_status(sent) + + +_Principal = tuple[str] | tuple[str, str] | tuple[str, str, str] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("creator", "sender", "expected"), + [ + pytest.param(("client-a",), ("client-b",), 404, id="different-client"), + pytest.param(("client-a",), None, 404, id="unauthenticated-sender"), + pytest.param(("client-a", "alice"), ("client-a", "bob"), 404, id="same-client-different-subject"), + pytest.param(("client-a", "alice"), ("client-a",), 404, id="same-client-no-subject"), + pytest.param( + ("client-a", "alice", "https://i1"), ("client-a", "alice", "https://i2"), 404, id="different-issuer" + ), + pytest.param(None, ("client-a",), 404, id="unauthenticated-creator"), + pytest.param(("client-a",), ("client-a",), 202, id="same-client"), + pytest.param(("client-a", "alice"), ("client-a", "alice"), 202, id="same-client-and-subject"), + pytest.param(None, None, 202, id="both-unauthenticated"), + ], +) +async def test_sse_post_requires_the_credential_that_created_the_session( + creator: _Principal | None, + sender: _Principal | None, + expected: int, +): + """The session endpoint URL issued to one authenticated principal must not + accept messages from a request authenticated as a different one.""" + transport = SseServerTransport("/messages/") + session_id_received = anyio.Event() + session_ids: list[str] = [] + client_disconnected = anyio.Event() + + async def get_send(message: Message) -> None: + # The first body chunk is the SSE event announcing the session URI to POST messages to. + if message["type"] == "http.response.body" and not session_ids: + match = re.search(rb"session_id=([0-9a-f]{32})", message.get("body", b"")) + assert match is not None, f"expected the endpoint event first, got {message!r}" + session_ids.append(match.group(1).decode()) + session_id_received.set() + + async def get_receive() -> Message: + # The SSE client stays connected until the test signals otherwise. + await client_disconnected.wait() + return {"type": "http.disconnect"} + + creator_user = _authenticated_user(*creator) if creator is not None else None + sender_user = _authenticated_user(*sender) if sender is not None else None + + async def hold_sse_connection() -> None: + """Establish the SSE session as `creator` and keep it open, as a server would.""" + scope, _, _, _ = _sse_scope("GET", "/sse", creator_user) + with anyio.fail_after(5): + async with transport.connect_sse(scope, get_receive, get_send) as (read_stream, write_stream): + async with read_stream, write_stream: # pragma: no branch + async for _ in read_stream: + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(hold_sse_connection) + with anyio.fail_after(5): + await session_id_received.wait() + + assert await _post_message(transport, session_ids[0], sender_user) == expected + + client_disconnected.set() + + # Once the connection is gone the session is no longer routable. + assert await _post_message(transport, session_ids[0], creator_user) == 404 + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_non_http_scope(): + """connect_sse refuses ASGI scopes that are not HTTP requests.""" + transport = SseServerTransport("/messages/") + with pytest.raises(ValueError): + async with transport.connect_sse({"type": "websocket"}, _no_receive, _no_send): + raise NotImplementedError + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_disallowed_host(): + """connect_sse rejects requests whose Host header fails the configured security check.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("GET", "/sse", None) + scope["headers"] = [(b"host", b"disallowed.example.com")] + + with pytest.raises(ValueError): + async with transport.connect_sse(scope, receive, send): + raise NotImplementedError + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_post_without_a_session_id_returns_400(): + """POSTs to the messages endpoint must include a session_id query parameter.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_malformed_session_id_returns_400(): + """A session_id that is not 32 hex characters is rejected before any session lookup.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None, query_string=b"session_id=not-hex") + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_disallowed_host_is_rejected_before_session_lookup(): + """The transport security check on POST runs before any session-ID handling.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + scope["headers"] = [(b"host", b"disallowed.example.com"), (b"content-type", b"application/json")] + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_round_trip_delivers_posted_messages_and_streams_responses(): + """A POSTed JSON-RPC message reaches the server's read stream, and a message + written to the server's write stream is sent to the client as an SSE event.""" + transport = SseServerTransport("/messages/") + session = _SseSession(transport) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(session.hold) + await session.ready.wait() + + # POST a parse-failing body: client gets 400, server's read stream receives the error. + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", None, query_string=f"session_id={session.session_id}".encode(), body=b"not json" + ) + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + assert isinstance(await session.next_read_item(), Exception) + + # POST a valid message: client gets 202, server's read stream receives it. + assert await _post_message(transport, session.session_id, None) == 202 + received = await session.next_read_item() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "ping" + + # Server writes a response: it appears as an SSE `message` event on the GET stream. + outgoing = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + await session.write_stream.send(SessionMessage(outgoing)) + chunk = await session.next_body_chunk() + assert b"event: message" in chunk + assert outgoing.model_dump_json(by_alias=True, exclude_unset=True).encode() in chunk + + session.disconnect() + + +class _SseSession: + """Drive an in-process SSE GET connection and surface what the server reads and the client receives. + + `hold` runs the connection in a background task and consumes the server-side read stream + into a buffer so that `handle_post_message` (which writes to that stream with a zero-capacity + channel) never blocks the test body. + """ + + def __init__(self, transport: SseServerTransport) -> None: + self.transport = transport + self.ready = anyio.Event() + self._disconnected = anyio.Event() + self._body_send, self._body_recv = anyio.create_memory_object_stream[bytes](16) + self._read_send, self._read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](16) + self.session_id = "" + self.write_stream: WriteStream[SessionMessage] + + async def hold(self) -> None: + scope, _, _, _ = _sse_scope("GET", "/sse", None) + async with self.transport.connect_sse(scope, self._receive, self._send) as (read, write): + self.write_stream = write + async with read, write, self._body_send, self._body_recv, self._read_send, self._read_recv: + async for item in read: + await self._read_send.send(item) + + def disconnect(self) -> None: + self._disconnected.set() + + async def next_read_item(self) -> SessionMessage | Exception: + return await self._read_recv.receive() + + async def next_body_chunk(self) -> bytes: + return await self._body_recv.receive() + + async def _receive(self) -> Message: + await self._disconnected.wait() + return {"type": "http.disconnect"} + + async def _send(self, message: Message) -> None: + if message["type"] != "http.response.body": + return + body: bytes = message.get("body", b"") + if not self.session_id: + match = re.search(rb"session_id=([0-9a-f]{32})", body) + assert match is not None, f"expected the endpoint event first, got {message!r}" + self.session_id = match.group(1).decode() + self.ready.set() + else: + await self._body_send.send(body) + + +async def _no_receive() -> Message: + raise NotImplementedError + + +async def _no_send(message: Message) -> None: + raise NotImplementedError diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 47cfbf14a..ba7554796 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -8,11 +8,13 @@ import anyio import httpx import pytest -from starlette.types import Message +from starlette.types import Message, Scope from mcp import Client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext, streamable_http_manager +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @@ -413,3 +415,166 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +def _user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _request_scope( + *, session_id: str | None = None, user: AuthenticatedUser | None = None, method: str = "POST" +) -> Scope: + """Build an ASGI scope for a request to the MCP endpoint.""" + headers = [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ] + if session_id is not None: + headers.append((b"mcp-session-id", session_id.encode())) + scope: Scope = { + "type": "http", + "method": method, + "path": "/mcp", + "headers": headers, + } + if user is not None: + scope["user"] = user + return scope + + +async def _open_session(manager: StreamableHTTPSessionManager, user: AuthenticatedUser | None) -> str: + """Create a new session as `user` and return its session ID.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(_request_scope(user=user), mock_receive, mock_send) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + headers = dict(response_start.get("headers", [])) + return headers[MCP_SESSION_ID_HEADER.encode()].decode() + + +async def _request_session( + manager: StreamableHTTPSessionManager, session_id: str, user: AuthenticatedUser | None, method: str = "POST" +) -> int: + """Send a request for an existing session as `user` and return the response status.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request( + _request_scope(session_id=session_id, user=user, method=method), mock_receive, mock_send + ) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + return response_start["status"] + + +@pytest.fixture +async def manager_with_live_session(): + """A running manager around a real `Server`. Sessions remain registered until + `manager.run()` exits because `Server.run` blocks waiting for an initialize message.""" + manager = StreamableHTTPSessionManager(app=Server("test-session-credentials")) + async with manager.run(): + yield manager + + +@pytest.mark.anyio +async def test_session_accepts_requests_from_the_credential_that_created_it( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Requests presenting the same credential as the one that created the session are served.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + status = await _request_session(manager, session_id, _user("client-a")) + + # The request passes the manager's credential check and reaches the + # session's transport, instead of being answered with 404 by the manager. + assert status != 404 + + +@pytest.mark.anyio +@pytest.mark.parametrize("method", ["POST", "GET", "DELETE"]) +async def test_session_rejects_requests_from_a_different_credential( + manager_with_live_session: StreamableHTTPSessionManager, method: str +) -> None: + """A session created by one credential cannot be used with another credential, whatever the method.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, _user("client-b"), method) == 404 + # The session is still registered and still serves its creator. + assert await _request_session(manager, session_id, _user("client-a")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_from_a_different_subject_of_the_same_client( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Two end-users that share an OAuth client cannot use each other's sessions.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a", subject="alice")) + + assert await _request_session(manager, session_id, _user("client-a", subject="bob")) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject=None)) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_with_the_same_subject_from_a_different_issuer( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A subject is unique only per issuer, so a colliding subject from a different issuer is not the same principal.""" + manager = manager_with_live_session + creator = _user("client-a", subject="alice", issuer="https://issuer.one") + session_id = await _open_session(manager, creator) + + other_issuer = _user("client-a", subject="alice", issuer="https://issuer.two") + assert await _request_session(manager, session_id, other_issuer) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) == 404 + assert await _request_session(manager, session_id, creator) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_unauthenticated_requests_for_an_authenticated_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created with a credential cannot be used without one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, None) == 404 + + +@pytest.mark.anyio +async def test_session_rejects_authenticated_requests_for_an_anonymous_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created without a credential cannot be used with one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, _user("client-a")) == 404 + + +@pytest.mark.anyio +async def test_anonymous_session_accepts_anonymous_requests( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Servers without authentication keep working: no credential on either side.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, None) != 404 diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 000000000..be28980b5 --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,88 @@ +"""Tests for the transport-security request validation middleware.""" + +import pytest +from starlette.requests import Request + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def _request(host: str | None, origin: str | None, content_type: str | None = "application/json") -> Request: + headers: list[tuple[bytes, bytes]] = [] + if content_type is not None: + headers.append((b"content-type", content_type.encode())) + if host is not None: + headers.append((b"host", host.encode())) + if origin is not None: + headers.append((b"origin", origin.encode())) + return Request({"type": "http", "method": "GET", "headers": headers}) + + +SETTINGS = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["good.example", "wild.example:*"], + allowed_origins=["http://good.example", "http://wild.example:*"], +) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("host", "origin", "expected"), + [ + pytest.param(None, None, 421, id="missing-host"), + pytest.param("evil.example", None, 421, id="host-no-match"), + pytest.param("evil.example:9000", None, 421, id="host-wildcard-base-mismatch"), + pytest.param("good.example", None, None, id="host-exact-no-origin"), + pytest.param("wild.example:9000", None, None, id="host-wildcard-match"), + pytest.param("good.example", "http://evil.example", 403, id="origin-no-match"), + pytest.param("good.example", "http://evil.example:9000", 403, id="origin-wildcard-base-mismatch"), + pytest.param("good.example", "http://good.example", None, id="origin-exact"), + pytest.param("good.example", "http://wild.example:9000", None, id="origin-wildcard-match"), + ], +) +async def test_validate_request_checks_host_then_origin( + host: str | None, origin: str | None, expected: int | None +) -> None: + """Host is checked first, then Origin; exact and wildcard-port allowlist entries are honoured.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request(host, origin)) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_skips_host_and_origin_when_protection_is_disabled() -> None: + """With DNS-rebinding protection off, any Host/Origin is accepted.""" + middleware = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +async def test_validate_request_defaults_to_protection_disabled() -> None: + """Constructing the middleware without settings leaves DNS-rebinding protection off.""" + middleware = TransportSecurityMiddleware() + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("content_type", "expected"), + [ + pytest.param("application/json", None, id="json"), + pytest.param("application/json; charset=utf-8", None, id="json-with-charset"), + pytest.param("APPLICATION/JSON", None, id="case-insensitive"), + pytest.param("text/plain", 400, id="wrong-type"), + pytest.param(None, 400, id="missing"), + ], +) +async def test_validate_request_checks_content_type_on_post(content_type: str | None, expected: int | None) -> None: + """POST requests must carry an application/json Content-Type, regardless of DNS-rebinding settings.""" + middleware = TransportSecurityMiddleware() + response = await middleware.validate_request(_request("any", None, content_type=content_type), is_post=True) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_ignores_content_type_on_get() -> None: + """Content-Type is only enforced for POST requests.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request("good.example", None, content_type="text/plain")) + assert response is None