From 0bb9f25024aad3928743810ebdbf6f9e828b0e6c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 22 May 2026 22:15:18 +0000 Subject: [PATCH] Scope experimental tasks to the session that created them Task IDs generated by run_task() now embed an opaque per-session marker, and the default handlers registered by enable_tasks() use it to restrict each session to its own tasks: tasks/get, tasks/result, and tasks/cancel respond with "task not found" for another session's task, and tasks/list returns only the requesting session's tasks. The default tasks/list handler no longer exposes the store's pagination cursor, which is derived from the unfiltered listing and could identify another session's task. Tasks whose IDs carry no marker (explicitly chosen IDs, tasks created directly through a TaskStore, or tasks on stateless servers) remain usable by any requestor that presents the exact ID, but are no longer included in tasks/list responses. Passing an explicit task_id to run_task() is deprecated because such tasks cannot be associated with the session that created them. The TaskStore interface and the wire protocol are unchanged; the marker travels inside the task ID string. --- docs/experimental/tasks-server.md | 23 ++ src/mcp/server/experimental/__init__.py | 1 + .../server/experimental/request_context.py | 50 ++- .../server/experimental/session_features.py | 6 + .../experimental/task_result_handler.py | 8 +- src/mcp/server/experimental/task_scope.py | 75 ++++ src/mcp/server/experimental/task_support.py | 13 +- src/mcp/server/lowlevel/experimental.py | 60 +++- src/mcp/server/lowlevel/server.py | 2 +- .../experimental/tasks/server/test_server.py | 7 +- .../tasks/server/test_task_scope.py | 150 ++++++++ .../tasks/server/test_task_visibility.py | 321 ++++++++++++++++++ .../tasks/test_request_context.py | 29 ++ 13 files changed, 731 insertions(+), 14 deletions(-) create mode 100644 src/mcp/server/experimental/task_scope.py create mode 100644 tests/experimental/tasks/server/test_task_scope.py create mode 100644 tests/experimental/tasks/server/test_task_visibility.py diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md index 761dc5de5c..c6b94814cd 100644 --- a/docs/experimental/tasks-server.md +++ b/docs/experimental/tasks-server.md @@ -53,6 +53,29 @@ That's it. `enable_tasks()` automatically: - Registers handlers for `tasks/get`, `tasks/result`, `tasks/list`, `tasks/cancel` - Updates server capabilities +## Task Visibility + +Task IDs generated by `run_task()` embed an opaque marker identifying the session that +created the task, and the default handlers use it to restrict each session to its own +tasks: `tasks/get`, `tasks/result`, and `tasks/cancel` respond with "task not found" for +another session's task, and `tasks/list` returns only the requesting session's tasks. A +client that reconnects gets a new session and can no longer reach tasks it created on the +previous one. + +A task ID has no session marker when it was passed to `run_task()` explicitly, when the +task was created directly through the `TaskStore`, or when the server runs in stateless +mode (each request gets a fresh session, so tasks must remain reachable across requests). +Such tasks are accessible to any requestor that presents the exact task ID, and are never +included in `tasks/list` responses because the server cannot tell which session they +belong to. Treat these task IDs as capabilities: generate them with enough entropy that +they cannot be guessed, share them only with the intended recipient, and prefer short +TTLs. Passing an explicit `task_id` to `run_task()` is deprecated for this reason. + +To scope tasks to something other than the session — for example a user identity from your +authorization layer — register your own handlers with `@server.experimental.get_task()`, +`@server.experimental.get_task_result()`, `@server.experimental.list_tasks()`, and +`@server.experimental.cancel_task()` instead of relying on the defaults. + ## Tool Declaration Tools declare task support via the `execution.taskSupport` field: diff --git a/src/mcp/server/experimental/__init__.py b/src/mcp/server/experimental/__init__.py index 824bb8b8be..91c6dcf3e8 100644 --- a/src/mcp/server/experimental/__init__.py +++ b/src/mcp/server/experimental/__init__.py @@ -8,4 +8,5 @@ - mcp.server.experimental.task_support.TaskSupport - mcp.server.experimental.task_result_handler.TaskResultHandler - mcp.server.experimental.request_context.Experimental +- mcp.server.experimental.task_scope (session scoping of task IDs) """ diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 78e75beb6a..0d69836355 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -7,11 +7,15 @@ WARNING: These APIs are experimental and may change without notice. """ +import warnings from collections.abc import Awaitable, Callable from dataclasses import dataclass, field -from typing import Any +from typing import Any, overload + +from typing_extensions import deprecated from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.experimental.task_scope import scoped_task_id from mcp.server.experimental.task_support import TaskSupport from mcp.server.session import ServerSession from mcp.shared.exceptions import McpError @@ -29,6 +33,14 @@ Tool, ) +EXPLICIT_TASK_ID_DEPRECATION = ( + "Passing an explicit task_id to run_task is deprecated. A task created with an " + "explicit ID is not associated with the session that created it: any requestor " + "that presents the ID can read its status and result or cancel it, and it never " + "appears in tasks/list. Omit task_id to let the SDK generate an ID associated " + "with the creating session." +) + @dataclass class Experimental: @@ -143,6 +155,25 @@ def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: return False return True + @overload + async def run_task( + self, + work: Callable[[ServerTaskContext], Awaitable[Result]], + *, + task_id: None = None, + model_immediate_response: str | None = None, + ) -> CreateTaskResult: ... + + @overload + @deprecated(EXPLICIT_TASK_ID_DEPRECATION) + async def run_task( + self, + work: Callable[[ServerTaskContext], Awaitable[Result]], + *, + task_id: str, + model_immediate_response: str | None = None, + ) -> CreateTaskResult: ... + async def run_task( self, work: Callable[[ServerTaskContext], Awaitable[Result]], @@ -167,9 +198,17 @@ async def run_task( When work() returns a Result, the task is auto-completed with that result. If work() raises an exception, the task is auto-failed. + Generated task IDs embed the session's task scope so that the default + task handlers only serve the task to the session that created it. An + explicitly provided `task_id` is used verbatim and is not associated + with the session, so any session can access it through the default + handlers; passing one is deprecated for that reason. + Args: work: Async function that does the actual work - task_id: Optional task ID (generated if not provided) + task_id: Deprecated. Optional task ID, used verbatim and not + associated with the creating session. Omit it to let the SDK + generate one. model_immediate_response: Optional string to include in _meta as io.modelcontextprotocol/model-immediate-response @@ -196,6 +235,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: WARNING: This API is experimental and may change without notice. """ + if task_id is not None: + warnings.warn(EXPLICIT_TASK_ID_DEPRECATION, DeprecationWarning, stacklevel=2) if self._task_support is None: raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.") if self._session is None: @@ -210,6 +251,11 @@ async def work(task: ServerTaskContext) -> CallToolResult: # Access task_group via TaskSupport - raises if not in run() context task_group = support.task_group + if task_id is None: + session_scope = self._session.experimental.task_session_scope + if session_scope is not None: + task_id = scoped_task_id(session_scope) + task = await support.store.create_task(self.task_metadata, task_id) task_ctx = ServerTaskContext( diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py index 4842da5175..c118537fa2 100644 --- a/src/mcp/server/experimental/session_features.py +++ b/src/mcp/server/experimental/session_features.py @@ -40,6 +40,12 @@ class ExperimentalServerSessionFeatures: def __init__(self, session: "ServerSession") -> None: self._session = session + # Opaque marker identifying this session for task scoping. Assigned by + # TaskSupport.configure_session(). Task IDs generated by run_task() + # embed it so the default task handlers can restrict task access to + # the session that created the task. None means tasks created on this + # session are not associated with it (e.g. stateless servers). + self.task_session_scope: str | None = None async def get_task(self, task_id: str) -> types.GetTaskResult: """ diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 0b869216e8..1cf7f69749 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -46,6 +46,11 @@ class TaskResultHandler: 4. Blocks until task reaches terminal state 5. Returns the final result + Prefer `server.experimental.enable_tasks()`, whose default tasks/result + handler wraps `handle()` and only serves tasks created by the requesting + session. A custom handler that calls `handle()` directly is responsible + for deciding which requestors may access which tasks. + Usage: # Create handler with store and queue handler = TaskResultHandler(task_store, message_queue) @@ -55,9 +60,6 @@ class TaskResultHandler: async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: ctx = server.request_context return await handler.handle(req, ctx.session, ctx.request_id) - - # Or use the convenience method - handler.register(server) """ def __init__( diff --git a/src/mcp/server/experimental/task_scope.py b/src/mcp/server/experimental/task_scope.py new file mode 100644 index 0000000000..c33cf55725 --- /dev/null +++ b/src/mcp/server/experimental/task_scope.py @@ -0,0 +1,75 @@ +""" +Session scoping for experimental task identifiers. + +Task IDs generated by `run_task()` embed an opaque, per-session marker (the +"session scope") so that the default task handlers can tell which session +created a task. The default handlers for tasks/get, tasks/result, tasks/list, +and tasks/cancel only operate on tasks created by the requesting session. + +Task IDs without a session scope (explicitly provided IDs, IDs created +directly through a TaskStore, or IDs created in stateless mode) have no known +creator. They can be used with tasks/get, tasks/result, and tasks/cancel from +any session - possession of the ID is what grants access - but they are never +included in tasks/list responses. + +WARNING: These APIs are experimental and may change without notice. +""" + +import re +from uuid import uuid4 + +__all__ = [ + "new_session_scope", + "scoped_task_id", + "session_scope_of", + "task_in_session_scope", + "task_listable_in_session_scope", +] + +# A scoped task ID has the form "<32 hex chars>:". Both halves must +# match exactly so that explicitly chosen task IDs are never mistaken for +# scoped ones. \Z rather than $ so a trailing newline cannot match. +_SCOPED_TASK_ID = re.compile( + r"\A(?P[0-9a-f]{32}):" + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\Z" +) + + +def new_session_scope() -> str: + """Create a new opaque session scope token.""" + return uuid4().hex + + +def scoped_task_id(session_scope: str) -> str: + """Generate a task ID associated with the given session scope.""" + return f"{session_scope}:{uuid4()}" + + +def session_scope_of(task_id: str) -> str | None: + """Return the session scope embedded in a task ID, or None if it has none.""" + match = _SCOPED_TASK_ID.match(task_id) + return match.group("scope") if match else None + + +def task_in_session_scope(task_id: str, session_scope: str | None) -> bool: + """Whether a task may be used by a requestor with the given session scope. + + Used by tasks/get, tasks/result, and tasks/cancel. A task whose ID carries + no session scope has no known creator, so possession of the ID is what + grants access to it: it can be used from any session. + """ + embedded = session_scope_of(task_id) + return embedded is None or embedded == session_scope + + +def task_listable_in_session_scope(task_id: str, session_scope: str | None) -> bool: + """Whether a task may be included in a tasks/list response for the given session scope. + + Used by tasks/list. Listing is stricter than access by ID: a task is only + listed to the session that created it. Tasks with no session scope are + never listed because they have no known creator, and requestors with no + session scope are never shown any tasks because the server cannot tell + them apart. + """ + embedded = session_scope_of(task_id) + return embedded is not None and embedded == session_scope diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index dbb2ed6d2b..8e91faf73b 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -13,6 +13,7 @@ from anyio.abc import TaskGroup from mcp.server.experimental.task_result_handler import TaskResultHandler +from mcp.server.experimental.task_scope import new_session_scope from mcp.server.session import ServerSession from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue @@ -83,7 +84,7 @@ async def run(self) -> AsyncIterator[None]: finally: self._task_group = None - def configure_session(self, session: ServerSession) -> None: + def configure_session(self, session: ServerSession, *, stateless: bool = False) -> None: """ Configure a session for task support. @@ -91,12 +92,22 @@ def configure_session(self, session: ServerSession) -> None: responses to queued requests (elicitation, sampling) are routed back to the waiting resolvers. + It also assigns the session a task session scope. Task IDs generated + by `run_task()` embed this scope, and the default task handlers only + operate on tasks created by the requesting session. Stateless sessions + are not assigned a scope: each request runs on a fresh session, so a + task created by one request could never be retrieved by a later one if + tasks were bound to the session that created them. + Called automatically by Server.run() for each new session. Args: session: The session to configure + stateless: Whether the session belongs to a stateless server run """ session.add_response_router(self.handler) + if not stateless and session.experimental.task_session_scope is None: + session.experimental.task_session_scope = new_session_scope() @classmethod def in_memory(cls) -> "TaskSupport": diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 42353e4ea0..737d6bb2cd 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -9,6 +9,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING +from mcp.server.experimental.task_scope import task_in_session_scope, task_listable_in_session_scope from mcp.server.experimental.task_support import TaskSupport from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.shared.exceptions import McpError @@ -31,6 +32,7 @@ ServerResult, ServerTasksCapability, ServerTasksRequestsCapability, + Task, TasksCallCapability, TasksCancelCapability, TasksListCapability, @@ -125,8 +127,38 @@ def enable_tasks( return self._task_support + def _requestor_session_scope(self) -> str | None: + """Return the task session scope of the session making the current request.""" + return self._server.request_context.session.experimental.task_session_scope + + def _require_task_in_requestor_scope(self, task_id: str) -> None: + """Reject task IDs that belong to a different session. + + Task IDs generated by `run_task()` embed the creating session's + scope. The default handlers treat a task created by another session + exactly like a task that does not exist, so a requestor cannot tell + whether such a task exists. Task IDs without an embedded scope are + accepted from any session. + + Raises: + McpError: With INVALID_PARAMS if the task belongs to another session. + """ + if not task_in_session_scope(task_id, self._requestor_session_scope()): + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task not found: {task_id}", + ) + ) + def _register_default_task_handlers(self) -> None: - """Register default handlers for task operations.""" + """Register default handlers for task operations. + + Each default handler only operates on tasks created by the requesting + session (see `_require_task_in_requestor_scope`), and tasks/list only + returns the requesting session's own tasks (see + `task_listable_in_session_scope`). + """ assert self._task_support is not None support = self._task_support @@ -134,6 +166,7 @@ def _register_default_task_handlers(self) -> None: if GetTaskRequest not in self._request_handlers: async def _default_get_task(req: GetTaskRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) task = await support.store.get_task(req.params.taskId) if task is None: raise McpError( @@ -160,6 +193,7 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: if GetTaskPayloadRequest not in self._request_handlers: async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) ctx = self._server.request_context result = await support.handler.handle(req, ctx.session, ctx.request_id) return ServerResult(result) @@ -170,9 +204,26 @@ async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: if ListTasksRequest not in self._request_handlers: async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: - cursor = req.params.cursor if req.params else None - tasks, next_cursor = await support.store.list_tasks(cursor) - return ServerResult(ListTasksResult(tasks=tasks, nextCursor=next_cursor)) + requestor_scope = self._requestor_session_scope() + if requestor_scope is None: + # The server cannot tell this requestor apart from any + # other, so there are no tasks it can be shown. + return ServerResult(ListTasksResult(tasks=[])) + # Return every task that belongs to the requesting session in + # a single page. The store's pagination cursor is never sent + # to the requestor: it is derived from the unfiltered listing, + # so it could identify a task belonging to a different + # session. For the same reason the request's cursor is not + # forwarded to the store. + own_tasks: list[Task] = [] + cursor: str | None = None + while True: + page, cursor = await support.store.list_tasks(cursor) + own_tasks.extend( + task for task in page if task_listable_in_session_scope(task.taskId, requestor_scope) + ) + if cursor is None: + return ServerResult(ListTasksResult(tasks=own_tasks)) self._request_handlers[ListTasksRequest] = _default_list_tasks @@ -180,6 +231,7 @@ async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: if CancelTaskRequest not in self._request_handlers: async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) result = await cancel_task(support.store, req.params.taskId) return ServerResult(result) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 2dd1a8277a..7d925de32b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -667,7 +667,7 @@ async def run( # Configure task support for this session if enabled task_support = self._experimental_handlers.task_support if self._experimental_handlers else None if task_support is not None: - task_support.configure_session(session) + task_support.configure_session(session, stateless=stateless) await stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 7209ed412a..64a1dabb04 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -506,13 +506,14 @@ async def run_server() -> None: # Create a task directly in the store for testing task = await store.create_task(TaskMetadata(ttl=60000)) - # Test list_tasks (default handler) + # Test list_tasks (default handler). Tasks created directly in the + # store have no session scope, so they are reachable by ID but not + # included in tasks/list (see test_task_scope.py). list_result = await client_session.send_request( ClientRequest(ListTasksRequest()), ListTasksResult, ) - assert len(list_result.tasks) == 1 - assert list_result.tasks[0].taskId == task.taskId + assert list_result.tasks == [] # Test get_task (default handler - found) get_result = await client_session.send_request( diff --git a/tests/experimental/tasks/server/test_task_scope.py b/tests/experimental/tasks/server/test_task_scope.py new file mode 100644 index 0000000000..c13b728a86 --- /dev/null +++ b/tests/experimental/tasks/server/test_task_scope.py @@ -0,0 +1,150 @@ +"""Unit tests for the task session-scope helpers. + +A session scope is an opaque marker assigned to each session by +TaskSupport.configure_session(). Task IDs generated by run_task() embed it so +the default task handlers can tell which session created a task. See +test_task_visibility.py for the end-to-end behaviour these helpers produce. +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import anyio +import pytest + +from mcp.server import Server +from mcp.server.experimental.task_scope import ( + new_session_scope, + scoped_task_id, + session_scope_of, + task_in_session_scope, + task_listable_in_session_scope, +) +from mcp.server.experimental.task_support import TaskSupport +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage + + +def test_new_session_scope_is_unique() -> None: + assert new_session_scope() != new_session_scope() + + +def test_scoped_task_id_round_trips_its_scope() -> None: + scope = new_session_scope() + + task_id = scoped_task_id(scope) + + assert session_scope_of(task_id) == scope + + +def test_scoped_task_ids_are_unique_within_a_scope() -> None: + scope = new_session_scope() + + assert scoped_task_id(scope) != scoped_task_id(scope) + + +@pytest.mark.parametrize( + "task_id", + [ + "plain-task-id", + "550e8400-e29b-41d4-a716-446655440000", # bare uuid4 + "", + # Right shape but the scope half is not 32 hex chars. + "not-a-scope:550e8400-e29b-41d4-a716-446655440000", + # Right scope half but the suffix is not a uuid4. + "0123456789abcdef0123456789abcdef:not-a-uuid", + # Uppercase hex is not produced by new_session_scope(). + "0123456789ABCDEF0123456789ABCDEF:550e8400-e29b-41d4-a716-446655440000", + ], +) +def test_session_scope_of_returns_none_for_unscoped_ids(task_id: str) -> None: + assert session_scope_of(task_id) is None + + +def test_a_scoped_task_is_usable_only_from_the_scope_that_created_it() -> None: + scope = new_session_scope() + task_id = scoped_task_id(scope) + + assert task_in_session_scope(task_id, scope) is True + assert task_in_session_scope(task_id, new_session_scope()) is False + assert task_in_session_scope(task_id, None) is False + + +def test_an_unscoped_task_is_usable_from_any_scope() -> None: + assert task_in_session_scope("plain-task-id", new_session_scope()) is True + assert task_in_session_scope("plain-task-id", None) is True + + +def test_a_scoped_task_is_listable_only_in_the_scope_that_created_it() -> None: + scope = new_session_scope() + task_id = scoped_task_id(scope) + + assert task_listable_in_session_scope(task_id, scope) is True + assert task_listable_in_session_scope(task_id, new_session_scope()) is False + assert task_listable_in_session_scope(task_id, None) is False + + +def test_an_unscoped_task_is_never_listable() -> None: + assert task_listable_in_session_scope("plain-task-id", new_session_scope()) is False + assert task_listable_in_session_scope("plain-task-id", None) is False + + +@asynccontextmanager +async def _make_session() -> AsyncIterator[ServerSession]: + """Create a ServerSession suitable for inspecting configure_session().""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + options = InitializationOptions( + server_name="test", + server_version="0", + capabilities=Server("test").get_capabilities(NotificationOptions(), {}), + ) + async with ( + server_to_client_receive, + client_to_server_send, + ServerSession(client_to_server_receive, server_to_client_send, options) as session, + ): + yield session + + +@pytest.mark.anyio +async def test_configure_session_assigns_a_scope() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + assert session.experimental.task_session_scope is None + + support.configure_session(session) + + assert session.experimental.task_session_scope is not None + + +@pytest.mark.anyio +async def test_configure_session_assigns_distinct_scopes_per_session() -> None: + support = TaskSupport.in_memory() + async with _make_session() as first, _make_session() as second: + support.configure_session(first) + support.configure_session(second) + + assert first.experimental.task_session_scope != second.experimental.task_session_scope + + +@pytest.mark.anyio +async def test_configure_session_is_idempotent() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + support.configure_session(session) + scope = session.experimental.task_session_scope + support.configure_session(session) + + assert session.experimental.task_session_scope == scope + + +@pytest.mark.anyio +async def test_configure_session_assigns_no_scope_to_stateless_sessions() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + support.configure_session(session, stateless=True) + + assert session.experimental.task_session_scope is None diff --git a/tests/experimental/tasks/server/test_task_visibility.py b/tests/experimental/tasks/server/test_task_visibility.py new file mode 100644 index 0000000000..2ee16399ea --- /dev/null +++ b/tests/experimental/tasks/server/test_task_visibility.py @@ -0,0 +1,321 @@ +"""End-to-end tests for which clients can see and control a task. + +Every test runs a real server and one or more in-memory client sessions. A +task started with run_task() belongs to the client session that started it: +that session can poll it, list it, and cancel it, while every other session +is told the task does not exist. Tasks whose IDs carry no session marker +(explicitly chosen IDs, or tasks on stateless servers) are usable by any +session that knows the ID, but are never listed. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import AsyncExitStack +from typing import Any + +import anyio +import pytest +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.shared.message import SessionMessage +from mcp.types import ( + TASK_REQUIRED, + CallToolResult, + CreateTaskResult, + ListTasksResult, + TextContent, + Tool, + ToolExecution, +) + +# The `connect` fixture: each call opens a new client session against the test server. +Connect = Callable[..., Awaitable[ClientSession]] + +# Enough tasks that the bundled in-memory store needs more than one page (of 10) +# to list them, so listings that span store pages are exercised. +MORE_TASKS_THAN_ONE_STORE_PAGE = 11 + + +def build_task_server(store: TaskStore | None = None) -> Server: + """Build a server exposing three task tools. + + - "greet" finishes immediately and returns a greeting. + - "long_running_job" keeps running until the server shuts down. + - "nightly_export" is a singleton job: every invocation uses the + explicitly chosen task ID "the-nightly-export". + """ + server = Server("task-visibility-test-server") + server.experimental.enable_tasks(store=store) + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name=name, + description=name, + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + for name in ("greet", "long_running_job", "nightly_export") + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: + async def greet(task: ServerTaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text=f"Hello, {arguments['name']}!")]) + + async def long_running_job(task: ServerTaskContext) -> CallToolResult: + await anyio.sleep_forever() + raise AssertionError("unreachable") # pragma: no cover + + async def nightly_export(task: ServerTaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text="exported")]) + + run_task = server.request_context.experimental.run_task + if name == "nightly_export": + return await run_task(nightly_export, task_id="the-nightly-export") + return await run_task(greet if name == "greet" else long_running_job) + + return server + + +async def open_client( + server: Server, task_group: TaskGroup, stack: AsyncExitStack, *, stateless: bool = False +) -> ClientSession: + """Connect a new client session to `server` over in-memory streams.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + stateless=stateless, + ) + + task_group.start_soon(run_server) + client = await stack.enter_async_context(ClientSession(server_to_client_receive, client_to_server_send)) + await client.initialize() + return client + + +@pytest.fixture +def task_server() -> Server: + return build_task_server() + + +@pytest.fixture +async def connect(task_server: Server) -> AsyncIterator[Connect]: + """A factory that opens a new client session against the test server on each call.""" + async with anyio.create_task_group() as task_group, AsyncExitStack() as stack: + + async def _connect(*, stateless: bool = False) -> ClientSession: + return await open_client(task_server, task_group, stack, stateless=stateless) + + yield _connect + task_group.cancel_scope.cancel() + + +async def start_task(client: ClientSession, tool: str = "long_running_job", **arguments: Any) -> str: + """Start `tool` as a task and return the new task's ID.""" + result = await client.experimental.call_tool_as_task(tool, arguments) + return result.task.taskId + + +async def wait_until_finished(client: ClientSession, task_id: str) -> None: + """Poll the task until it reaches a terminal status.""" + with anyio.fail_after(5): + async for _ in client.experimental.poll_task(task_id): + pass + + +async def listed_task_ids(client: ClientSession) -> list[str]: + """Return the IDs of every task the server lists for this client.""" + return [task.taskId for task in (await client.experimental.list_tasks()).tasks] + + +# --- What the client that started a task can do with it --- + + +@pytest.mark.anyio +async def test_a_client_can_poll_its_own_task_to_completion_and_read_the_result(connect: Connect) -> None: + client = await connect() + task_id = await start_task(client, "greet", name="Ada") + await wait_until_finished(client, task_id) + + result = await client.experimental.get_task_result(task_id, CallToolResult) + + assert result.content == [TextContent(type="text", text="Hello, Ada!")] + + +@pytest.mark.anyio +async def test_a_client_sees_its_own_task_when_listing_tasks(connect: Connect) -> None: + client = await connect() + task_id = await start_task(client) + + listed = await listed_task_ids(client) + + assert listed == [task_id] + + +@pytest.mark.anyio +async def test_a_client_can_cancel_its_own_task(connect: Connect) -> None: + client = await connect() + task_id = await start_task(client) + + cancelled = await client.experimental.cancel_task(task_id) + + assert cancelled.status == "cancelled" + + +# --- What a client cannot do with a task started by another client --- + + +@pytest.mark.anyio +async def test_a_client_cannot_get_the_status_of_another_clients_task(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + task_id = await start_task(creator) + + with pytest.raises(McpError, match="Task not found"): + await other_client.experimental.get_task(task_id) + + +@pytest.mark.anyio +async def test_a_client_cannot_get_the_result_of_another_clients_task(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + task_id = await start_task(creator, "greet", name="Ada") + await wait_until_finished(creator, task_id) + + with pytest.raises(McpError, match="Task not found"): + await other_client.experimental.get_task_result(task_id, CallToolResult) + + +@pytest.mark.anyio +async def test_a_client_cannot_cancel_another_clients_task(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + task_id = await start_task(creator) + + with pytest.raises(McpError, match="Task not found"): + await other_client.experimental.cancel_task(task_id) + + # The task is unaffected. + assert (await creator.experimental.get_task(task_id)).status == "working" + + +@pytest.mark.anyio +async def test_a_client_does_not_see_another_clients_task_when_listing_tasks(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + await start_task(creator) + + listed = await listed_task_ids(other_client) + + assert listed == [] + + +@pytest.mark.anyio +async def test_each_client_lists_only_its_own_tasks(connect: Connect) -> None: + first_client = await connect() + second_client = await connect() + first_task = await start_task(first_client) + second_task = await start_task(second_client) + + assert await listed_task_ids(first_client) == [first_task] + assert await listed_task_ids(second_client) == [second_task] + + +@pytest.mark.anyio +async def test_listing_tasks_reveals_nothing_about_other_clients_tasks_however_many_there_are( + connect: Connect, +) -> None: + """The listing must not identify other clients' tasks through any field, including the pagination cursor.""" + creator = await connect() + other_client = await connect() + for _ in range(MORE_TASKS_THAN_ONE_STORE_PAGE): + await start_task(creator) + + listing = await other_client.experimental.list_tasks() + + assert listing == ListTasksResult(tasks=[], nextCursor=None) + + +@pytest.mark.anyio +async def test_a_client_with_more_than_one_store_page_of_tasks_lists_all_of_them(connect: Connect) -> None: + client = await connect() + started = {await start_task(client) for _ in range(MORE_TASKS_THAN_ONE_STORE_PAGE)} + + listing = await client.experimental.list_tasks() + + assert {task.taskId for task in listing.tasks} == started + assert listing.nextCursor is None + + +# --- Tasks that do not belong to any client session --- + + +@pytest.mark.anyio +# Choosing the task ID instead of letting the SDK generate one is deprecated for +# exactly the behaviour this test demonstrates: the task is not tied to the +# session that created it. +@pytest.mark.filterwarnings("ignore:Passing an explicit task_id") +async def test_a_task_whose_id_was_chosen_by_the_server_is_accessible_to_every_client(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + await wait_until_finished(creator, await start_task(creator, "nightly_export")) + + status = await other_client.experimental.get_task("the-nightly-export") + + assert status.status == "completed" + + +@pytest.mark.anyio +async def test_a_stateless_server_serves_a_task_to_any_session_that_knows_its_id(connect: Connect) -> None: + first_session = await connect(stateless=True) + second_session = await connect(stateless=True) + task_id = await start_task(first_session, "greet", name="Ada") + await wait_until_finished(second_session, task_id) + + result = await second_session.experimental.get_task_result(task_id, CallToolResult) + + assert result.content == [TextContent(type="text", text="Hello, Ada!")] + + +@pytest.mark.anyio +async def test_a_stateless_server_lists_no_tasks(connect: Connect) -> None: + session = await connect(stateless=True) + await start_task(session) + + listed = await listed_task_ids(session) + + assert listed == [] + + +# --- The behaviour does not depend on the bundled in-memory store --- + + +@pytest.mark.anyio +async def test_clients_are_isolated_when_the_server_uses_a_custom_task_store() -> None: + class CustomTaskStore(InMemoryTaskStore): + """A stand-in for a user-provided TaskStore implementation.""" + + server = build_task_server(store=CustomTaskStore()) + + async with anyio.create_task_group() as task_group, AsyncExitStack() as stack: + creator = await open_client(server, task_group, stack) + other_client = await open_client(server, task_group, stack) + task_id = await start_task(creator) + + with pytest.raises(McpError, match="Task not found"): + await other_client.experimental.get_task(task_id) + + assert (await creator.experimental.get_task(task_id)).status == "working" + task_group.cancel_scope.cancel() diff --git a/tests/experimental/tasks/test_request_context.py b/tests/experimental/tasks/test_request_context.py index 5fa5da81af..e52ec4403a 100644 --- a/tests/experimental/tasks/test_request_context.py +++ b/tests/experimental/tasks/test_request_context.py @@ -3,6 +3,7 @@ import pytest from mcp.server.experimental.request_context import Experimental +from mcp.server.experimental.task_context import ServerTaskContext from mcp.shared.exceptions import McpError from mcp.types import ( METHOD_NOT_FOUND, @@ -11,6 +12,7 @@ TASK_REQUIRED, ClientCapabilities, ClientTasksCapability, + Result, TaskMetadata, Tool, ToolExecution, @@ -164,3 +166,30 @@ def test_can_use_tool_forbidden_without_task_support() -> None: def test_can_use_tool_none_without_task_support() -> None: exp = Experimental(_client_capabilities=ClientCapabilities()) assert exp.can_use_tool(None) is True + + +@pytest.mark.anyio +async def test_run_task_with_an_explicit_task_id_emits_a_deprecation_warning() -> None: + """An explicitly provided task ID is not associated with the creating session, so passing one is deprecated.""" + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + + async def work(task: ServerTaskContext) -> Result: + raise AssertionError("unreachable") # pragma: no cover + + with pytest.warns(DeprecationWarning, match="not associated with the session"): + # Task support is not configured, so the call fails after the + # deprecated argument has been reported. + with pytest.raises(RuntimeError, match="Task support not enabled"): + # The deliberate use of the deprecated overload is the point of this test. + await exp.run_task(work, task_id="explicitly-chosen") # pyright: ignore[reportDeprecated] + + +@pytest.mark.anyio +async def test_run_task_without_a_task_id_does_not_warn() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + + async def work(task: ServerTaskContext) -> Result: + raise AssertionError("unreachable") # pragma: no cover + + with pytest.raises(RuntimeError, match="Task support not enabled"): + await exp.run_task(work)