Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/mcp/server/auth/middleware/bearer_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import time
from typing import Any
from typing import Any, TypedDict

from pydantic import AnyHttpUrl
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
Expand All @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken):
self.scopes = auth_info.scopes


class AuthorizationContext(TypedDict):
client_id: str
issuer: str | None
subject: str | None


def authorization_context(user: AuthenticatedUser) -> AuthorizationContext:
"""Identify the principal `user` represents, for transports to compare
against the principal that created a session. Components the token
verifier does not supply are `None`, so the comparison degrades to the
remaining components.

See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for
a verifier that populates `subject` and `claims` from an introspection
response."""
token = user.access_token
issuer = (token.claims or {}).get("iss")
return AuthorizationContext(
client_id=token.client_id,
issuer=str(issuer) if issuer is not None else None,
subject=token.subject,
)


class BearerAuthBackend(AuthenticationBackend):
"""
Authentication backend that validates Bearer tokens using a TokenVerifier.
Expand Down
65 changes: 45 additions & 20 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ async def handle_sse(request):
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
from mcp.server.transport_security import (
TransportSecurityMiddleware,
TransportSecuritySettings,
Expand All @@ -75,6 +76,9 @@ class SseServerTransport:

_endpoint: str
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
# Identity of the credential that created each session; requests for a
# session must present the same credential.
_session_owners: dict[UUID, AuthorizationContext]
_security: TransportSecurityMiddleware

def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
Expand Down Expand Up @@ -115,6 +119,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |

self._endpoint = endpoint
self._read_stream_writers = {}
self._session_owners = {}
self._security = TransportSecurityMiddleware(security_settings)
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")

Expand Down Expand Up @@ -142,6 +147,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

session_id = uuid4()
user = scope.get("user")
if isinstance(user, AuthenticatedUser):
self._session_owners[session_id] = authorization_context(user)
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")

Expand Down Expand Up @@ -177,26 +185,34 @@ async def sse_writer():
}
)

async with anyio.create_task_group() as tg:

async def response_wrapper(scope: Scope, receive: Receive, send: Send):
"""
The EventSourceResponse returning signals a client close / disconnect.
In this case we close our side of the streams to signal the client that
the connection has been closed.
"""
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
scope, receive, send
)
await read_stream_writer.aclose()
await write_stream_reader.aclose()
logging.debug(f"Client session disconnected {session_id}")

logger.debug("Starting SSE response task")
tg.start_soon(response_wrapper, scope, receive, send)

logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
try:
async with anyio.create_task_group() as tg:

async def response_wrapper(scope: Scope, receive: Receive, send: Send):
"""
The EventSourceResponse returning signals a client close / disconnect.
In this case we close our side of the streams to signal the client that
the connection has been closed.
"""
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
scope, receive, send
)
await read_stream_writer.aclose()
await write_stream_reader.aclose()
await sse_stream_reader.aclose()
logging.debug(f"Client session disconnected {session_id}")

logger.debug("Starting SSE response task")
tg.start_soon(response_wrapper, scope, receive, send)

logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
finally:
# The connection is gone: stop routing messages to this session
# and drop its entries so they do not accumulate for the lifetime
# of the transport.
self._read_stream_writers.pop(session_id, None)
self._session_owners.pop(session_id, None)

async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
logger.debug("Handling POST message")
Expand Down Expand Up @@ -227,6 +243,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)

user = scope.get("user")
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
if requestor != self._session_owners.get(session_id):
# A session can only be used with the credential that created it.
# Respond exactly as if the session did not exist.
logger.warning("Rejecting message for session %s: credential does not match", session_id)
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)

body = await request.body()
logger.debug(f"Received JSON: {body}")

Expand Down
49 changes: 34 additions & 15 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import contextlib
import logging
from collections.abc import AsyncIterator
from http import HTTPStatus
from typing import Any
from uuid import uuid4

Expand All @@ -15,6 +14,7 @@
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
from mcp.server.lowlevel.server import Server as MCPServer
from mcp.server.streamable_http import (
MCP_SESSION_ID_HEADER,
Expand Down Expand Up @@ -88,6 +88,9 @@ def __init__(
# Session tracking (only used if not stateless)
self._session_creation_lock = anyio.Lock()
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
# Identity of the credential that created each session; requests for a
# session must present the same credential.
self._session_owners: dict[str, AuthorizationContext] = {}

# The task group will be set during lifespan
self._task_group = None
Expand Down Expand Up @@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
self._task_group = None
# Clear any remaining server instances
self._server_instances.clear()
self._session_owners.clear()

async def handle_request(
self,
Expand Down Expand Up @@ -227,12 +231,32 @@ async def _handle_stateful_request(
request = Request(scope, receive)
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)

user = scope.get("user")
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None

# Existing session case
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
transport = self._server_instances[request_mcp_session_id]
if requestor != self._session_owners.get(request_mcp_session_id):
# A session can only be used with the credential that created
# it. Respond exactly as if the session did not exist.
logger.warning(
"Rejecting request for session %s: credential does not match the one that created the session",
request_mcp_session_id[:64],
)
body = JSONRPCError(
jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found")
)
response = Response(
body.model_dump_json(by_alias=True, exclude_none=True),
status_code=404,
media_type="application/json",
)
await response(scope, receive, send)
return
logger.debug("Session already exists, handling request directly")
# Push back idle deadline on activity
if transport.idle_scope is not None and self.session_idle_timeout is not None:
if transport.idle_scope is not None and self.session_idle_timeout is not None: # pragma: no cover
transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout
await transport.handle_request(scope, receive, send)
return
Expand All @@ -251,6 +275,8 @@ async def _handle_stateful_request(
)

assert http_transport.mcp_session_id is not None
if requestor is not None:
self._session_owners[http_transport.mcp_session_id] = requestor
self._server_instances[http_transport.mcp_session_id] = http_transport
logger.info(f"Created new transport with session ID: {new_session_id}")

Expand Down Expand Up @@ -281,6 +307,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
assert http_transport.mcp_session_id is not None
logger.info(f"Session {http_transport.mcp_session_id} idle timeout")
self._server_instances.pop(http_transport.mcp_session_id, None)
self._session_owners.pop(http_transport.mcp_session_id, None)
await http_transport.terminate()
except Exception:
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
Expand All @@ -296,6 +323,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
"active instances."
)
del self._server_instances[http_transport.mcp_session_id]
self._session_owners.pop(http_transport.mcp_session_id, None)

# Assert task group is not None for type checking
assert self._task_group is not None
Expand All @@ -306,19 +334,10 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
await http_transport.handle_request(scope, receive, send)
else:
# Unknown or expired session ID - return 404 per MCP spec
# TODO: Align error code once spec clarifies
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1821
error_response = JSONRPCError(
jsonrpc="2.0",
id="server-error",
error=ErrorData(
code=INVALID_REQUEST,
message="Session not found",
),
body = JSONRPCError(
jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found")
)
response = Response(
content=error_response.model_dump_json(by_alias=True, exclude_none=True),
status_code=HTTPStatus.NOT_FOUND,
media_type="application/json",
body.model_dump_json(by_alias=True, exclude_none=True), status_code=404, media_type="application/json"
)
await response(scope, receive, send)
Loading
Loading