diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md index 761dc5de5c..c6b94814cd 100644 --- a/docs/experimental/tasks-server.md +++ b/docs/experimental/tasks-server.md @@ -53,6 +53,29 @@ That's it. `enable_tasks()` automatically: - Registers handlers for `tasks/get`, `tasks/result`, `tasks/list`, `tasks/cancel` - Updates server capabilities +## Task Visibility + +Task IDs generated by `run_task()` embed an opaque marker identifying the session that +created the task, and the default handlers use it to restrict each session to its own +tasks: `tasks/get`, `tasks/result`, and `tasks/cancel` respond with "task not found" for +another session's task, and `tasks/list` returns only the requesting session's tasks. A +client that reconnects gets a new session and can no longer reach tasks it created on the +previous one. + +A task ID has no session marker when it was passed to `run_task()` explicitly, when the +task was created directly through the `TaskStore`, or when the server runs in stateless +mode (each request gets a fresh session, so tasks must remain reachable across requests). +Such tasks are accessible to any requestor that presents the exact task ID, and are never +included in `tasks/list` responses because the server cannot tell which session they +belong to. Treat these task IDs as capabilities: generate them with enough entropy that +they cannot be guessed, share them only with the intended recipient, and prefer short +TTLs. Passing an explicit `task_id` to `run_task()` is deprecated for this reason. + +To scope tasks to something other than the session — for example a user identity from your +authorization layer — register your own handlers with `@server.experimental.get_task()`, +`@server.experimental.get_task_result()`, `@server.experimental.list_tasks()`, and +`@server.experimental.cancel_task()` instead of relying on the defaults. 
+ ## Tool Declaration Tools declare task support via the `execution.taskSupport` field: diff --git a/src/mcp/server/experimental/__init__.py b/src/mcp/server/experimental/__init__.py index 824bb8b8be..91c6dcf3e8 100644 --- a/src/mcp/server/experimental/__init__.py +++ b/src/mcp/server/experimental/__init__.py @@ -8,4 +8,5 @@ - mcp.server.experimental.task_support.TaskSupport - mcp.server.experimental.task_result_handler.TaskResultHandler - mcp.server.experimental.request_context.Experimental +- mcp.server.experimental.task_scope (session scoping of task IDs) """ diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 78e75beb6a..0d69836355 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -7,11 +7,15 @@ WARNING: These APIs are experimental and may change without notice. """ +import warnings from collections.abc import Awaitable, Callable from dataclasses import dataclass, field -from typing import Any +from typing import Any, overload + +from typing_extensions import deprecated from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.experimental.task_scope import scoped_task_id from mcp.server.experimental.task_support import TaskSupport from mcp.server.session import ServerSession from mcp.shared.exceptions import McpError @@ -29,6 +33,14 @@ Tool, ) +EXPLICIT_TASK_ID_DEPRECATION = ( + "Passing an explicit task_id to run_task is deprecated. A task created with an " + "explicit ID is not associated with the session that created it: any requestor " + "that presents the ID can read its status and result or cancel it, and it never " + "appears in tasks/list. Omit task_id to let the SDK generate an ID associated " + "with the creating session." 
+) + @dataclass class Experimental: @@ -143,6 +155,25 @@ def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: return False return True + @overload + async def run_task( + self, + work: Callable[[ServerTaskContext], Awaitable[Result]], + *, + task_id: None = None, + model_immediate_response: str | None = None, + ) -> CreateTaskResult: ... + + @overload + @deprecated(EXPLICIT_TASK_ID_DEPRECATION) + async def run_task( + self, + work: Callable[[ServerTaskContext], Awaitable[Result]], + *, + task_id: str, + model_immediate_response: str | None = None, + ) -> CreateTaskResult: ... + async def run_task( self, work: Callable[[ServerTaskContext], Awaitable[Result]], @@ -167,9 +198,17 @@ async def run_task( When work() returns a Result, the task is auto-completed with that result. If work() raises an exception, the task is auto-failed. + Generated task IDs embed the session's task scope so that the default + task handlers only serve the task to the session that created it. An + explicitly provided `task_id` is used verbatim and is not associated + with the session, so any session can access it through the default + handlers; passing one is deprecated for that reason. + Args: work: Async function that does the actual work - task_id: Optional task ID (generated if not provided) + task_id: Deprecated. Optional task ID, used verbatim and not + associated with the creating session. Omit it to let the SDK + generate one. model_immediate_response: Optional string to include in _meta as io.modelcontextprotocol/model-immediate-response @@ -196,6 +235,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: WARNING: This API is experimental and may change without notice. """ + if task_id is not None: + warnings.warn(EXPLICIT_TASK_ID_DEPRECATION, DeprecationWarning, stacklevel=2) if self._task_support is None: raise RuntimeError("Task support not enabled. 
Call server.experimental.enable_tasks() first.") if self._session is None: @@ -210,6 +251,11 @@ async def work(task: ServerTaskContext) -> CallToolResult: # Access task_group via TaskSupport - raises if not in run() context task_group = support.task_group + if task_id is None: + session_scope = self._session.experimental.task_session_scope + if session_scope is not None: + task_id = scoped_task_id(session_scope) + task = await support.store.create_task(self.task_metadata, task_id) task_ctx = ServerTaskContext( diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py index 4842da5175..c118537fa2 100644 --- a/src/mcp/server/experimental/session_features.py +++ b/src/mcp/server/experimental/session_features.py @@ -40,6 +40,12 @@ class ExperimentalServerSessionFeatures: def __init__(self, session: "ServerSession") -> None: self._session = session + # Opaque marker identifying this session for task scoping. Assigned by + # TaskSupport.configure_session(). Task IDs generated by run_task() + # embed it so the default task handlers can restrict task access to + # the session that created the task. None means tasks created on this + # session are not associated with it (e.g. stateless servers). + self.task_session_scope: str | None = None async def get_task(self, task_id: str) -> types.GetTaskResult: """ diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 0b869216e8..1cf7f69749 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -46,6 +46,11 @@ class TaskResultHandler: 4. Blocks until task reaches terminal state 5. Returns the final result + Prefer `server.experimental.enable_tasks()`, whose default tasks/result + handler wraps `handle()` and only serves tasks created by the requesting + session. 
A custom handler that calls `handle()` directly is responsible + for deciding which requestors may access which tasks. + Usage: # Create handler with store and queue handler = TaskResultHandler(task_store, message_queue) @@ -55,9 +60,6 @@ class TaskResultHandler: async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: ctx = server.request_context return await handler.handle(req, ctx.session, ctx.request_id) - - # Or use the convenience method - handler.register(server) """ def __init__( diff --git a/src/mcp/server/experimental/task_scope.py b/src/mcp/server/experimental/task_scope.py new file mode 100644 index 0000000000..c33cf55725 --- /dev/null +++ b/src/mcp/server/experimental/task_scope.py @@ -0,0 +1,75 @@ +""" +Session scoping for experimental task identifiers. + +Task IDs generated by `run_task()` embed an opaque, per-session marker (the +"session scope") so that the default task handlers can tell which session +created a task. The default handlers for tasks/get, tasks/result, tasks/list, +and tasks/cancel only operate on tasks created by the requesting session. + +Task IDs without a session scope (explicitly provided IDs, IDs created +directly through a TaskStore, or IDs created in stateless mode) have no known +creator. They can be used with tasks/get, tasks/result, and tasks/cancel from +any session - possession of the ID is what grants access - but they are never +included in tasks/list responses. + +WARNING: These APIs are experimental and may change without notice. +""" + +import re +from uuid import uuid4 + +__all__ = [ + "new_session_scope", + "scoped_task_id", + "session_scope_of", + "task_in_session_scope", + "task_listable_in_session_scope", +] + +# A scoped task ID has the form "<32 hex chars>:<uuid4>". Both halves must +# match exactly so that explicitly chosen task IDs are never mistaken for +# scoped ones. \Z rather than $ so a trailing newline cannot match. 
+_SCOPED_TASK_ID = re.compile( + r"\A(?P<scope>[0-9a-f]{32}):" + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\Z" +) + + +def new_session_scope() -> str: + """Create a new opaque session scope token.""" + return uuid4().hex + + +def scoped_task_id(session_scope: str) -> str: + """Generate a task ID associated with the given session scope.""" + return f"{session_scope}:{uuid4()}" + + +def session_scope_of(task_id: str) -> str | None: + """Return the session scope embedded in a task ID, or None if it has none.""" + match = _SCOPED_TASK_ID.match(task_id) + return match.group("scope") if match else None + + +def task_in_session_scope(task_id: str, session_scope: str | None) -> bool: + """Whether a task may be used by a requestor with the given session scope. + + Used by tasks/get, tasks/result, and tasks/cancel. A task whose ID carries + no session scope has no known creator, so possession of the ID is what + grants access to it: it can be used from any session. + """ + embedded = session_scope_of(task_id) + return embedded is None or embedded == session_scope + + +def task_listable_in_session_scope(task_id: str, session_scope: str | None) -> bool: + """Whether a task may be included in a tasks/list response for the given session scope. + + Used by tasks/list. Listing is stricter than access by ID: a task is only + listed to the session that created it. Tasks with no session scope are + never listed because they have no known creator, and requestors with no + session scope are never shown any tasks because the server cannot tell + them apart. 
+ """ + embedded = session_scope_of(task_id) + return embedded is not None and embedded == session_scope diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index dbb2ed6d2b..8e91faf73b 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -13,6 +13,7 @@ from anyio.abc import TaskGroup from mcp.server.experimental.task_result_handler import TaskResultHandler +from mcp.server.experimental.task_scope import new_session_scope from mcp.server.session import ServerSession from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue @@ -83,7 +84,7 @@ async def run(self) -> AsyncIterator[None]: finally: self._task_group = None - def configure_session(self, session: ServerSession) -> None: + def configure_session(self, session: ServerSession, *, stateless: bool = False) -> None: """ Configure a session for task support. @@ -91,12 +92,22 @@ def configure_session(self, session: ServerSession) -> None: responses to queued requests (elicitation, sampling) are routed back to the waiting resolvers. + It also assigns the session a task session scope. Task IDs generated + by `run_task()` embed this scope, and the default task handlers only + operate on tasks created by the requesting session. Stateless sessions + are not assigned a scope: each request runs on a fresh session, so a + task created by one request could never be retrieved by a later one if + tasks were bound to the session that created them. + Called automatically by Server.run() for each new session. 
Args: session: The session to configure + stateless: Whether the session belongs to a stateless server run """ session.add_response_router(self.handler) + if not stateless and session.experimental.task_session_scope is None: + session.experimental.task_session_scope = new_session_scope() @classmethod def in_memory(cls) -> "TaskSupport": diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 42353e4ea0..737d6bb2cd 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -9,6 +9,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING +from mcp.server.experimental.task_scope import task_in_session_scope, task_listable_in_session_scope from mcp.server.experimental.task_support import TaskSupport from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.shared.exceptions import McpError @@ -31,6 +32,7 @@ ServerResult, ServerTasksCapability, ServerTasksRequestsCapability, + Task, TasksCallCapability, TasksCancelCapability, TasksListCapability, @@ -125,8 +127,38 @@ def enable_tasks( return self._task_support + def _requestor_session_scope(self) -> str | None: + """Return the task session scope of the session making the current request.""" + return self._server.request_context.session.experimental.task_session_scope + + def _require_task_in_requestor_scope(self, task_id: str) -> None: + """Reject task IDs that belong to a different session. + + Task IDs generated by `run_task()` embed the creating session's + scope. The default handlers treat a task created by another session + exactly like a task that does not exist, so a requestor cannot tell + whether such a task exists. Task IDs without an embedded scope are + accepted from any session. + + Raises: + McpError: With INVALID_PARAMS if the task belongs to another session. 
+ """ + if not task_in_session_scope(task_id, self._requestor_session_scope()): + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task not found: {task_id}", + ) + ) + def _register_default_task_handlers(self) -> None: - """Register default handlers for task operations.""" + """Register default handlers for task operations. + + Each default handler only operates on tasks created by the requesting + session (see `_require_task_in_requestor_scope`), and tasks/list only + returns the requesting session's own tasks (see + `task_listable_in_session_scope`). + """ assert self._task_support is not None support = self._task_support @@ -134,6 +166,7 @@ def _register_default_task_handlers(self) -> None: if GetTaskRequest not in self._request_handlers: async def _default_get_task(req: GetTaskRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) task = await support.store.get_task(req.params.taskId) if task is None: raise McpError( @@ -160,6 +193,7 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: if GetTaskPayloadRequest not in self._request_handlers: async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) ctx = self._server.request_context result = await support.handler.handle(req, ctx.session, ctx.request_id) return ServerResult(result) @@ -170,9 +204,26 @@ async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: if ListTasksRequest not in self._request_handlers: async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: - cursor = req.params.cursor if req.params else None - tasks, next_cursor = await support.store.list_tasks(cursor) - return ServerResult(ListTasksResult(tasks=tasks, nextCursor=next_cursor)) + requestor_scope = self._requestor_session_scope() + if requestor_scope is None: + # The server cannot tell this requestor apart from any + # other, so there are no tasks it can be shown. 
+ return ServerResult(ListTasksResult(tasks=[])) + # Return every task that belongs to the requesting session in + # a single page. The store's pagination cursor is never sent + # to the requestor: it is derived from the unfiltered listing, + # so it could identify a task belonging to a different + # session. For the same reason the request's cursor is not + # forwarded to the store. + own_tasks: list[Task] = [] + cursor: str | None = None + while True: + page, cursor = await support.store.list_tasks(cursor) + own_tasks.extend( + task for task in page if task_listable_in_session_scope(task.taskId, requestor_scope) + ) + if cursor is None: + return ServerResult(ListTasksResult(tasks=own_tasks)) self._request_handlers[ListTasksRequest] = _default_list_tasks @@ -180,6 +231,7 @@ async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: if CancelTaskRequest not in self._request_handlers: async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) result = await cancel_task(support.store, req.params.taskId) return ServerResult(result) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 2dd1a8277a..7d925de32b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -667,7 +667,7 @@ async def run( # Configure task support for this session if enabled task_support = self._experimental_handlers.task_support if self._experimental_handlers else None if task_support is not None: - task_support.configure_session(session) + task_support.configure_session(session, stateless=stateless) await stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 7209ed412a..64a1dabb04 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ 
-506,13 +506,14 @@ async def run_server() -> None: # Create a task directly in the store for testing task = await store.create_task(TaskMetadata(ttl=60000)) - # Test list_tasks (default handler) + # Test list_tasks (default handler). Tasks created directly in the + # store have no session scope, so they are reachable by ID but not + # included in tasks/list (see test_task_scope.py). list_result = await client_session.send_request( ClientRequest(ListTasksRequest()), ListTasksResult, ) - assert len(list_result.tasks) == 1 - assert list_result.tasks[0].taskId == task.taskId + assert list_result.tasks == [] # Test get_task (default handler - found) get_result = await client_session.send_request( diff --git a/tests/experimental/tasks/server/test_task_scope.py b/tests/experimental/tasks/server/test_task_scope.py new file mode 100644 index 0000000000..c13b728a86 --- /dev/null +++ b/tests/experimental/tasks/server/test_task_scope.py @@ -0,0 +1,150 @@ +"""Unit tests for the task session-scope helpers. + +A session scope is an opaque marker assigned to each session by +TaskSupport.configure_session(). Task IDs generated by run_task() embed it so +the default task handlers can tell which session created a task. See +test_task_visibility.py for the end-to-end behaviour these helpers produce. 
+""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import anyio +import pytest + +from mcp.server import Server +from mcp.server.experimental.task_scope import ( + new_session_scope, + scoped_task_id, + session_scope_of, + task_in_session_scope, + task_listable_in_session_scope, +) +from mcp.server.experimental.task_support import TaskSupport +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage + + +def test_new_session_scope_is_unique() -> None: + assert new_session_scope() != new_session_scope() + + +def test_scoped_task_id_round_trips_its_scope() -> None: + scope = new_session_scope() + + task_id = scoped_task_id(scope) + + assert session_scope_of(task_id) == scope + + +def test_scoped_task_ids_are_unique_within_a_scope() -> None: + scope = new_session_scope() + + assert scoped_task_id(scope) != scoped_task_id(scope) + + +@pytest.mark.parametrize( + "task_id", + [ + "plain-task-id", + "550e8400-e29b-41d4-a716-446655440000", # bare uuid4 + "", + # Right shape but the scope half is not 32 hex chars. + "not-a-scope:550e8400-e29b-41d4-a716-446655440000", + # Right scope half but the suffix is not a uuid4. + "0123456789abcdef0123456789abcdef:not-a-uuid", + # Uppercase hex is not produced by new_session_scope(). 
+ "0123456789ABCDEF0123456789ABCDEF:550e8400-e29b-41d4-a716-446655440000", + ], +) +def test_session_scope_of_returns_none_for_unscoped_ids(task_id: str) -> None: + assert session_scope_of(task_id) is None + + +def test_a_scoped_task_is_usable_only_from_the_scope_that_created_it() -> None: + scope = new_session_scope() + task_id = scoped_task_id(scope) + + assert task_in_session_scope(task_id, scope) is True + assert task_in_session_scope(task_id, new_session_scope()) is False + assert task_in_session_scope(task_id, None) is False + + +def test_an_unscoped_task_is_usable_from_any_scope() -> None: + assert task_in_session_scope("plain-task-id", new_session_scope()) is True + assert task_in_session_scope("plain-task-id", None) is True + + +def test_a_scoped_task_is_listable_only_in_the_scope_that_created_it() -> None: + scope = new_session_scope() + task_id = scoped_task_id(scope) + + assert task_listable_in_session_scope(task_id, scope) is True + assert task_listable_in_session_scope(task_id, new_session_scope()) is False + assert task_listable_in_session_scope(task_id, None) is False + + +def test_an_unscoped_task_is_never_listable() -> None: + assert task_listable_in_session_scope("plain-task-id", new_session_scope()) is False + assert task_listable_in_session_scope("plain-task-id", None) is False + + +@asynccontextmanager +async def _make_session() -> AsyncIterator[ServerSession]: + """Create a ServerSession suitable for inspecting configure_session().""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + options = InitializationOptions( + server_name="test", + server_version="0", + capabilities=Server("test").get_capabilities(NotificationOptions(), {}), + ) + async with ( + server_to_client_receive, + client_to_server_send, + ServerSession(client_to_server_receive, server_to_client_send, 
options) as session, + ): + yield session + + +@pytest.mark.anyio +async def test_configure_session_assigns_a_scope() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + assert session.experimental.task_session_scope is None + + support.configure_session(session) + + assert session.experimental.task_session_scope is not None + + +@pytest.mark.anyio +async def test_configure_session_assigns_distinct_scopes_per_session() -> None: + support = TaskSupport.in_memory() + async with _make_session() as first, _make_session() as second: + support.configure_session(first) + support.configure_session(second) + + assert first.experimental.task_session_scope != second.experimental.task_session_scope + + +@pytest.mark.anyio +async def test_configure_session_is_idempotent() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + support.configure_session(session) + scope = session.experimental.task_session_scope + support.configure_session(session) + + assert session.experimental.task_session_scope == scope + + +@pytest.mark.anyio +async def test_configure_session_assigns_no_scope_to_stateless_sessions() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + support.configure_session(session, stateless=True) + + assert session.experimental.task_session_scope is None diff --git a/tests/experimental/tasks/server/test_task_visibility.py b/tests/experimental/tasks/server/test_task_visibility.py new file mode 100644 index 0000000000..2ee16399ea --- /dev/null +++ b/tests/experimental/tasks/server/test_task_visibility.py @@ -0,0 +1,321 @@ +"""End-to-end tests for which clients can see and control a task. + +Every test runs a real server and one or more in-memory client sessions. A +task started with run_task() belongs to the client session that started it: +that session can poll it, list it, and cancel it, while every other session +is told the task does not exist. 
Tasks whose IDs carry no session marker +(explicitly chosen IDs, or tasks on stateless servers) are usable by any +session that knows the ID, but are never listed. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import AsyncExitStack +from typing import Any + +import anyio +import pytest +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.shared.message import SessionMessage +from mcp.types import ( + TASK_REQUIRED, + CallToolResult, + CreateTaskResult, + ListTasksResult, + TextContent, + Tool, + ToolExecution, +) + +# The `connect` fixture: each call opens a new client session against the test server. +Connect = Callable[..., Awaitable[ClientSession]] + +# Enough tasks that the bundled in-memory store needs more than one page (of 10) +# to list them, so listings that span store pages are exercised. +MORE_TASKS_THAN_ONE_STORE_PAGE = 11 + + +def build_task_server(store: TaskStore | None = None) -> Server: + """Build a server exposing three task tools. + + - "greet" finishes immediately and returns a greeting. + - "long_running_job" keeps running until the server shuts down. + - "nightly_export" is a singleton job: every invocation uses the + explicitly chosen task ID "the-nightly-export". 
+ """ + server = Server("task-visibility-test-server") + server.experimental.enable_tasks(store=store) + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name=name, + description=name, + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + for name in ("greet", "long_running_job", "nightly_export") + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: + async def greet(task: ServerTaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text=f"Hello, {arguments['name']}!")]) + + async def long_running_job(task: ServerTaskContext) -> CallToolResult: + await anyio.sleep_forever() + raise AssertionError("unreachable") # pragma: no cover + + async def nightly_export(task: ServerTaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text="exported")]) + + run_task = server.request_context.experimental.run_task + if name == "nightly_export": + return await run_task(nightly_export, task_id="the-nightly-export") + return await run_task(greet if name == "greet" else long_running_job) + + return server + + +async def open_client( + server: Server, task_group: TaskGroup, stack: AsyncExitStack, *, stateless: bool = False +) -> ClientSession: + """Connect a new client session to `server` over in-memory streams.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + stateless=stateless, + ) + + task_group.start_soon(run_server) + client = await stack.enter_async_context(ClientSession(server_to_client_receive, client_to_server_send)) + await client.initialize() + 
return client
+
+
+@pytest.fixture
+def task_server() -> Server:
+    return build_task_server()
+
+
+@pytest.fixture
+async def connect(task_server: Server) -> AsyncIterator[Connect]:
+    """A factory that opens a new client session against the test server on each call."""
+    async with anyio.create_task_group() as task_group, AsyncExitStack() as stack:
+
+        async def _connect(*, stateless: bool = False) -> ClientSession:
+            return await open_client(task_server, task_group, stack, stateless=stateless)
+
+        yield _connect
+        task_group.cancel_scope.cancel()
+
+
+async def start_task(client: ClientSession, tool: str = "long_running_job", **arguments: Any) -> str:
+    """Start `tool` as a task and return the new task's ID."""
+    result = await client.experimental.call_tool_as_task(tool, arguments)
+    return result.task.taskId
+
+
+async def wait_until_finished(client: ClientSession, task_id: str) -> None:
+    """Poll the task until it reaches a terminal status, timing out with an error after 5 seconds."""
+    with anyio.fail_after(5):
+        async for _ in client.experimental.poll_task(task_id):
+            pass
+
+
+async def listed_task_ids(client: ClientSession) -> list[str]:
+    """Return the IDs of every task the server lists for this client."""
+    return [task.taskId for task in (await client.experimental.list_tasks()).tasks]
+
+
+# --- What the client that started a task can do with it ---
+
+
+@pytest.mark.anyio
+async def test_a_client_can_poll_its_own_task_to_completion_and_read_the_result(connect: Connect) -> None:
+    client = await connect()
+    task_id = await start_task(client, "greet", name="Ada")
+    await wait_until_finished(client, task_id)
+
+    result = await client.experimental.get_task_result(task_id, CallToolResult)
+
+    assert result.content == [TextContent(type="text", text="Hello, Ada!")]
+
+
+@pytest.mark.anyio
+async def test_a_client_sees_its_own_task_when_listing_tasks(connect: Connect) -> None:
+    client = await connect()
+    task_id = await start_task(client)
+
+    listed = await listed_task_ids(client)
+
+    assert listed == [task_id]
+
+
+@pytest.mark.anyio
+async def test_a_client_can_cancel_its_own_task(connect: Connect) -> None:
+    client = await connect()
+    task_id = await start_task(client)
+
+    cancelled = await client.experimental.cancel_task(task_id)
+
+    assert cancelled.status == "cancelled"
+
+
+# --- What a client cannot do with a task started by another client ---
+
+
+@pytest.mark.anyio
+async def test_a_client_cannot_get_the_status_of_another_clients_task(connect: Connect) -> None:
+    creator = await connect()
+    other_client = await connect()
+    task_id = await start_task(creator)
+
+    with pytest.raises(McpError, match="Task not found"):
+        await other_client.experimental.get_task(task_id)
+
+
+@pytest.mark.anyio
+async def test_a_client_cannot_get_the_result_of_another_clients_task(connect: Connect) -> None:
+    creator = await connect()
+    other_client = await connect()
+    task_id = await start_task(creator, "greet", name="Ada")
+    await wait_until_finished(creator, task_id)
+
+    with pytest.raises(McpError, match="Task not found"):
+        await other_client.experimental.get_task_result(task_id, CallToolResult)
+
+
+@pytest.mark.anyio
+async def test_a_client_cannot_cancel_another_clients_task(connect: Connect) -> None:
+    creator = await connect()
+    other_client = await connect()
+    task_id = await start_task(creator)
+
+    with pytest.raises(McpError, match="Task not found"):
+        await other_client.experimental.cancel_task(task_id)
+
+    # The task is unaffected.
+    assert (await creator.experimental.get_task(task_id)).status == "working"
+
+
+@pytest.mark.anyio
+async def test_a_client_does_not_see_another_clients_task_when_listing_tasks(connect: Connect) -> None:
+    creator = await connect()
+    other_client = await connect()
+    await start_task(creator)
+
+    listed = await listed_task_ids(other_client)
+
+    assert listed == []
+
+
+@pytest.mark.anyio
+async def test_each_client_lists_only_its_own_tasks(connect: Connect) -> None:
+    first_client = await connect()
+    second_client = await connect()
+    first_task = await start_task(first_client)
+    second_task = await start_task(second_client)
+
+    assert await listed_task_ids(first_client) == [first_task]
+    assert await listed_task_ids(second_client) == [second_task]
+
+
+@pytest.mark.anyio
+async def test_listing_tasks_reveals_nothing_about_other_clients_tasks_however_many_there_are(
+    connect: Connect,
+) -> None:
+    """The listing must not identify other clients' tasks through any field, including the pagination cursor."""
+    creator = await connect()
+    other_client = await connect()
+    for _ in range(MORE_TASKS_THAN_ONE_STORE_PAGE):
+        await start_task(creator)
+
+    listing = await other_client.experimental.list_tasks()
+
+    assert listing == ListTasksResult(tasks=[], nextCursor=None)
+
+
+@pytest.mark.anyio
+async def test_a_client_with_more_than_one_store_page_of_tasks_lists_all_of_them(connect: Connect) -> None:
+    client = await connect()
+    started = {await start_task(client) for _ in range(MORE_TASKS_THAN_ONE_STORE_PAGE)}
+
+    listing = await client.experimental.list_tasks()
+
+    assert {task.taskId for task in listing.tasks} == started
+    assert listing.nextCursor is None
+
+
+# --- Tasks that do not belong to any client session ---
+
+
+@pytest.mark.anyio
+# Choosing the task ID instead of letting the SDK generate one is deprecated for
+# exactly the behaviour this test demonstrates: the task is not tied to the
+# session that created it.
+@pytest.mark.filterwarnings("ignore:Passing an explicit task_id")
+async def test_a_task_whose_id_was_chosen_by_the_server_is_accessible_to_every_client(connect: Connect) -> None:
+    creator = await connect()
+    other_client = await connect()
+    await wait_until_finished(creator, await start_task(creator, "nightly_export"))
+
+    status = await other_client.experimental.get_task("the-nightly-export")
+
+    assert status.status == "completed"
+
+
+@pytest.mark.anyio
+async def test_a_stateless_server_serves_a_task_to_any_session_that_knows_its_id(connect: Connect) -> None:
+    first_session = await connect(stateless=True)
+    second_session = await connect(stateless=True)
+    task_id = await start_task(first_session, "greet", name="Ada")
+    await wait_until_finished(second_session, task_id)
+
+    result = await second_session.experimental.get_task_result(task_id, CallToolResult)
+
+    assert result.content == [TextContent(type="text", text="Hello, Ada!")]
+
+
+@pytest.mark.anyio
+async def test_a_stateless_server_lists_no_tasks(connect: Connect) -> None:
+    session = await connect(stateless=True)
+    await start_task(session)
+
+    listed = await listed_task_ids(session)
+
+    assert listed == []
+
+
+# --- The behaviour does not depend on the bundled in-memory store ---
+
+
+@pytest.mark.anyio
+async def test_clients_are_isolated_when_the_server_uses_a_custom_task_store() -> None:
+    class CustomTaskStore(InMemoryTaskStore):
+        """A stand-in for a user-provided TaskStore implementation."""
+
+    server = build_task_server(store=CustomTaskStore())
+
+    async with anyio.create_task_group() as task_group, AsyncExitStack() as stack:
+        creator = await open_client(server, task_group, stack)
+        other_client = await open_client(server, task_group, stack)
+        task_id = await start_task(creator)
+
+        with pytest.raises(McpError, match="Task not found"):
+            await other_client.experimental.get_task(task_id)
+
+        assert (await creator.experimental.get_task(task_id)).status == "working"
+        task_group.cancel_scope.cancel()
diff --git a/tests/experimental/tasks/test_request_context.py b/tests/experimental/tasks/test_request_context.py
index 5fa5da81af..e52ec4403a 100644
--- a/tests/experimental/tasks/test_request_context.py
+++ b/tests/experimental/tasks/test_request_context.py
@@ -3,6 +3,7 @@
 import pytest
 
 from mcp.server.experimental.request_context import Experimental
+from mcp.server.experimental.task_context import ServerTaskContext
 from mcp.shared.exceptions import McpError
 from mcp.types import (
     METHOD_NOT_FOUND,
@@ -11,6 +12,7 @@
     TASK_REQUIRED,
     ClientCapabilities,
     ClientTasksCapability,
+    Result,
     TaskMetadata,
     Tool,
     ToolExecution,
@@ -164,3 +166,35 @@ def test_can_use_tool_forbidden_without_task_support() -> None:
 def test_can_use_tool_none_without_task_support() -> None:
     exp = Experimental(_client_capabilities=ClientCapabilities())
     assert exp.can_use_tool(None) is True
+
+
+@pytest.mark.anyio
+async def test_run_task_with_an_explicit_task_id_emits_a_deprecation_warning() -> None:
+    """An explicitly provided task ID is not associated with the creating session, so passing one is deprecated."""
+    exp = Experimental(task_metadata=TaskMetadata(ttl=60000))
+
+    async def work(task: ServerTaskContext) -> Result:
+        raise AssertionError("unreachable")  # pragma: no cover
+
+    with pytest.warns(DeprecationWarning, match="not associated with the session"):
+        # Task support is not configured, so the call fails after the
+        # deprecated argument has been reported.
+        with pytest.raises(RuntimeError, match="Task support not enabled"):
+            # The deliberate use of the deprecated overload is the point of this test.
+            await exp.run_task(work, task_id="explicitly-chosen")  # pyright: ignore[reportDeprecated]
+
+
+@pytest.mark.anyio
+# Escalate the explicit-task_id deprecation warning to an error so this test fails
+# if it is ever emitted, whatever warning filters the test run is configured with.
+@pytest.mark.filterwarnings("error:Passing an explicit task_id")
+async def test_run_task_without_a_task_id_does_not_warn() -> None:
+    """The deprecation warning is reserved for explicit task IDs; a call without one must not trigger it."""
+    exp = Experimental(task_metadata=TaskMetadata(ttl=60000))
+
+    async def work(task: ServerTaskContext) -> Result:
+        raise AssertionError("unreachable")  # pragma: no cover
+
+    # Task support is not configured, so the call fails with RuntimeError and `work` never runs.
+    with pytest.raises(RuntimeError, match="Task support not enabled"):
+        await exp.run_task(work)