diff --git a/README.v2.md b/README.v2.md index d0851c04e5..a888f21bda 100644 --- a/README.v2.md +++ b/README.v2.md @@ -2489,7 +2489,6 @@ MCP servers declare capabilities during initialization: ## Documentation - [API Reference](https://modelcontextprotocol.github.io/python-sdk/api/) -- [Experimental Features (Tasks)](https://modelcontextprotocol.github.io/python-sdk/experimental/tasks/) - [Model Context Protocol documentation](https://modelcontextprotocol.io) - [Model Context Protocol specification](https://modelcontextprotocol.io/specification/latest) - [Officially supported servers](https://github.com/modelcontextprotocol/servers) diff --git a/docs/experimental/index.md b/docs/experimental/index.md deleted file mode 100644 index c97fe2a3d6..0000000000 --- a/docs/experimental/index.md +++ /dev/null @@ -1,42 +0,0 @@ -# Experimental Features - -!!! warning "Experimental APIs" - - The features in this section are experimental and may change without notice. - They track the evolving MCP specification and are not yet stable. - -This section documents experimental features in the MCP Python SDK. These features -implement draft specifications that are still being refined. - -## Available Experimental Features - -### [Tasks](tasks.md) - -Tasks enable asynchronous execution of MCP operations. Instead of waiting for a -long-running operation to complete, the server returns a task reference immediately. -Clients can then poll for status updates and retrieve results when ready. - -Tasks are useful for: - -- **Long-running computations** that would otherwise block -- **Batch operations** that process many items -- **Interactive workflows** that require user input (elicitation) or LLM assistance (sampling) - -## Using Experimental APIs - -Experimental features are accessed via the `.experimental` property: - -```python -# Server-side: enable task support (auto-registers default handlers) -server = Server(name="my-server") -server.experimental.enable_tasks() - -# Client-side -result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) -``` - -## Providing Feedback - -Since these features are experimental, feedback is especially valuable. If you encounter -issues or have suggestions, please open an issue on the -[python-sdk repository](https://github.com/modelcontextprotocol/python-sdk/issues). diff --git a/docs/experimental/tasks-client.md b/docs/experimental/tasks-client.md deleted file mode 100644 index 0374ed86b5..0000000000 --- a/docs/experimental/tasks-client.md +++ /dev/null @@ -1,361 +0,0 @@ -# Client Task Usage - -!!! warning "Experimental" - - Tasks are an experimental feature. The API may change without notice. - -This guide covers calling task-augmented tools from clients, handling the `input_required` status, and advanced patterns like receiving task requests from servers. - -## Quick Start - -Call a tool as a task and poll for the result: - -```python -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - -async with ClientSession(read, write) as session: - await session.initialize() - - # Call tool as task - result = await session.experimental.call_tool_as_task( - "process_data", - {"input": "hello"}, - ttl=60000, - ) - task_id = result.task.taskId - - # Poll until complete - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status} - {status.statusMessage or ''}") - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") -``` - -## Calling Tools as Tasks - -Use `call_tool_as_task()` to invoke a tool with task augmentation: - -```python -result = await session.experimental.call_tool_as_task( - "my_tool", # Tool name - {"arg": "value"}, # Arguments - ttl=60000, # Time-to-live in milliseconds - meta={"key": "val"}, # Optional metadata -) - -task_id = result.task.taskId -print(f"Task: {task_id}, Status: {result.task.status}") -``` - -The response is a `CreateTaskResult` containing: - -- `task.taskId` - Unique identifier for polling -- `task.status` - Initial status (usually `"working"`) -- `task.pollInterval` - Suggested polling interval (milliseconds) -- `task.ttl` - Time-to-live for results -- `task.createdAt` - Creation timestamp - -## Polling with poll_task - -The `poll_task()` async iterator polls until the task reaches a terminal state: - -```python -async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - if status.statusMessage: - print(f"Progress: {status.statusMessage}") -``` - -It automatically: - -- Respects the server's suggested `pollInterval` -- Stops when status is `completed`, `failed`, or `cancelled` -- Yields each status for progress display - -### Handling input_required - -When a task needs user input (elicitation), it transitions to `input_required`. You must call `get_task_result()` to receive and respond to the elicitation: - -```python -async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - if status.status == "input_required": - # This delivers the elicitation and waits for completion - final = await session.experimental.get_task_result(task_id, CallToolResult) - break -``` - -The elicitation callback (set during session creation) handles the actual user interaction. - -## Elicitation Callbacks - -To handle elicitation requests from the server, provide a callback when creating the session: - -```python -from mcp.types import ElicitRequestParams, ElicitResult - -async def handle_elicitation(context, params: ElicitRequestParams) -> ElicitResult: - # Display the message to the user - print(f"Server asks: {params.message}") - - # Collect user input (this is a simplified example) - response = input("Your response (y/n): ") - confirmed = response.lower() == "y" - - return ElicitResult( - action="accept", - content={"confirm": confirmed}, - ) - -async with ClientSession( - read, - write, - elicitation_callback=handle_elicitation, -) as session: - await session.initialize() - # ... call tasks that may require elicitation -``` - -## Sampling Callbacks - -Similarly, handle sampling requests with a callback: - -```python -from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent - -async def handle_sampling(context, params: CreateMessageRequestParams) -> CreateMessageResult: - # In a real implementation, call your LLM here - prompt = params.messages[-1].content.text if params.messages else "" - - # Return a mock response - return CreateMessageResult( - role="assistant", - content=TextContent(type="text", text=f"Response to: {prompt}"), - model="my-model", - ) - -async with ClientSession( - read, - write, - sampling_callback=handle_sampling, -) as session: - # ... -``` - -## Retrieving Results - -Once a task completes, retrieve the result: - -```python -if status.status == "completed": - result = await session.experimental.get_task_result(task_id, CallToolResult) - for content in result.content: - if hasattr(content, "text"): - print(content.text) - -elif status.status == "failed": - print(f"Task failed: {status.statusMessage}") - -elif status.status == "cancelled": - print("Task was cancelled") -``` - -The result type matches the original request: - -- `tools/call` → `CallToolResult` -- `sampling/createMessage` → `CreateMessageResult` -- `elicitation/create` → `ElicitResult` - -## Cancellation - -Cancel a running task: - -```python -cancel_result = await session.experimental.cancel_task(task_id) -print(f"Cancelled, status: {cancel_result.status}") -``` - -Note: Cancellation is cooperative—the server must check for and handle cancellation. - -## Listing Tasks - -View all tasks on the server: - -```python -result = await session.experimental.list_tasks() -for task in result.tasks: - print(f"{task.taskId}: {task.status}") - -# Handle pagination -while result.nextCursor: - result = await session.experimental.list_tasks(cursor=result.nextCursor) - for task in result.tasks: - print(f"{task.taskId}: {task.status}") -``` - -## Advanced: Client as Task Receiver - -Servers can send task-augmented requests to clients. This is useful when the server needs the client to perform async work (like complex sampling or user interaction). - -### Declaring Client Capabilities - -Register task handlers to declare what task-augmented requests your client accepts: - -```python -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.types import ( - CreateTaskResult, GetTaskResult, GetTaskPayloadResult, - TaskMetadata, ElicitRequestParams, -) -from mcp.shared.experimental.tasks import InMemoryTaskStore - -# Client-side task store -client_store = InMemoryTaskStore() - -async def handle_augmented_elicitation(context, params: ElicitRequestParams, task_metadata: TaskMetadata): - """Handle task-augmented elicitation from server.""" - # Create a task for this elicitation - task = await client_store.create_task(task_metadata) - - # Start async work (e.g., show UI, wait for user) - async def complete_elicitation(): - # ... do async work ... - result = ElicitResult(action="accept", content={"confirm": True}) - await client_store.store_result(task.taskId, result) - await client_store.update_task(task.taskId, status="completed") - - context.session._task_group.start_soon(complete_elicitation) - - # Return task reference immediately - return CreateTaskResult(task=task) - -async def handle_get_task(context, params): - """Handle tasks/get from server.""" - task = await client_store.get_task(params.taskId) - return GetTaskResult( - taskId=task.taskId, - status=task.status, - statusMessage=task.statusMessage, - createdAt=task.createdAt, - lastUpdatedAt=task.lastUpdatedAt, - ttl=task.ttl, - pollInterval=100, - ) - -async def handle_get_task_result(context, params): - """Handle tasks/result from server.""" - result = await client_store.get_result(params.taskId) - return GetTaskPayloadResult.model_validate(result.model_dump()) - -task_handlers = ExperimentalTaskHandlers( - augmented_elicitation=handle_augmented_elicitation, - get_task=handle_get_task, - get_task_result=handle_get_task_result, -) - -async with ClientSession( - read, - write, - experimental_task_handlers=task_handlers, -) as session: - # Client now accepts task-augmented elicitation from server - await session.initialize() -``` - -This enables flows where: - -1. Client calls a task-augmented tool -2. Server's tool work calls `task.elicit_as_task()` -3. Client receives task-augmented elicitation -4. Client creates its own task, does async work -5. Server polls client's task -6. Eventually both tasks complete - -## Complete Example - -A client that handles all task scenarios: - -```python -import anyio -from mcp.client.session import ClientSession -from mcp.client.stdio import stdio_client -from mcp.types import CallToolResult, ElicitRequestParams, ElicitResult - - -async def elicitation_callback(context, params: ElicitRequestParams) -> ElicitResult: - print(f"\n[Elicitation] {params.message}") - response = input("Confirm? (y/n): ") - return ElicitResult(action="accept", content={"confirm": response.lower() == "y"}) - - -async def main(): - async with stdio_client(command="python", args=["server.py"]) as (read, write): - async with ClientSession( - read, - write, - elicitation_callback=elicitation_callback, - ) as session: - await session.initialize() - - # List available tools - tools = await session.list_tools() - print("Tools:", [t.name for t in tools.tools]) - - # Call a task-augmented tool - print("\nCalling task tool...") - result = await session.experimental.call_tool_as_task( - "confirm_action", - {"action": "delete files"}, - ) - task_id = result.task.taskId - print(f"Task created: {task_id}") - - # Poll and handle input_required - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - if status.status == "input_required": - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") - break - - if status.status == "completed": - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") - - -if __name__ == "__main__": - anyio.run(main) -``` - -## Error Handling - -Handle task errors gracefully: - -```python -from mcp.shared.exceptions import MCPError - -try: - result = await session.experimental.call_tool_as_task("my_tool", args) - task_id = result.task.taskId - - async for status in session.experimental.poll_task(task_id): - if status.status == "failed": - raise RuntimeError(f"Task failed: {status.statusMessage}") - - final = await session.experimental.get_task_result(task_id, CallToolResult) - -except MCPError as e: - print(f"MCP error: {e.message}") -except Exception as e: - print(f"Error: {e}") -``` - -## Next Steps - -- [Server Implementation](tasks-server.md) - Build task-supporting servers -- [Tasks Overview](tasks.md) - Review lifecycle and concepts diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md deleted file mode 100644 index b350ee3bb6..0000000000 --- a/docs/experimental/tasks-server.md +++ /dev/null @@ -1,577 +0,0 @@ -# Server Task Implementation - -!!! warning "Experimental" - - Tasks are an experimental feature. The API may change without notice. - -This guide covers implementing task support in MCP servers, from basic setup to advanced patterns like elicitation and sampling within tasks. - -## Quick Start - -The simplest way to add task support: - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED - -server = Server("my-server") -server.experimental.enable_tasks() # Registers all task handlers automatically - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="process_data", - description="Process data asynchronously", - inputSchema={"type": "object", "properties": {"input": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ) - ] - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "process_data": - return await handle_process_data(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) - -async def handle_process_data(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Processing...") - result = arguments.get("input", "").upper() - return CallToolResult(content=[TextContent(type="text", text=result)]) - - return await ctx.experimental.run_task(work) -``` - -That's it. `enable_tasks()` automatically: - -- Creates an in-memory task store -- Registers handlers for `tasks/get`, `tasks/result`, `tasks/list`, `tasks/cancel` -- Updates server capabilities - -## Tool Declaration - -Tools declare task support via the `execution.taskSupport` field: - -```python -from mcp.types import Tool, ToolExecution, TASK_REQUIRED, TASK_OPTIONAL, TASK_FORBIDDEN - -Tool( - name="my_tool", - inputSchema={"type": "object"}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), # or TASK_OPTIONAL, TASK_FORBIDDEN -) -``` - -| Value | Meaning | -|-------|---------| -| `TASK_REQUIRED` | Tool **must** be called as a task | -| `TASK_OPTIONAL` | Tool supports both sync and task execution | -| `TASK_FORBIDDEN` | Tool **cannot** be called as a task (default) | - -Validate the request matches your tool's requirements: - -```python -@server.call_tool() -async def handle_tool(name: str, arguments: dict): - ctx = server.request_context - - if name == "required_task_tool": - ctx.experimental.validate_task_mode(TASK_REQUIRED) # Raises if not task mode - return await handle_as_task(arguments) - - elif name == "optional_task_tool": - if ctx.experimental.is_task: - return await handle_as_task(arguments) - else: - return handle_sync(arguments) -``` - -## The run_task Pattern - -`run_task()` is the recommended way to execute task work: - -```python -async def handle_my_tool(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Your work here - return CallToolResult(content=[TextContent(type="text", text="Done")]) - - return await ctx.experimental.run_task(work) -``` - -**What `run_task()` does:** - -1. Creates a task in the store -2. Spawns your work function in the background -3. Returns `CreateTaskResult` immediately -4. Auto-completes the task when your function returns -5. Auto-fails the task if your function raises - -**The `ServerTaskContext` provides:** - -- `task.task_id` - The task identifier -- `task.update_status(message)` - Update progress -- `task.complete(result)` - Explicitly complete (usually automatic) -- `task.fail(error)` - Explicitly fail -- `task.is_cancelled` - Check if cancellation requested - -## Status Updates - -Keep clients informed of progress: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Starting...") - - for i, item in enumerate(items): - await task.update_status(f"Processing {i+1}/{len(items)}") - await process_item(item) - - await task.update_status("Finalizing...") - return CallToolResult(content=[TextContent(type="text", text="Complete")]) -``` - -Status messages appear in `tasks/get` responses, letting clients show progress to users. - -## Elicitation Within Tasks - -Tasks can request user input via elicitation. This transitions the task to `input_required` status. - -### Form Elicitation - -Collect structured data from the user: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Waiting for confirmation...") - - result = await task.elicit( - message="Delete these files?", - requestedSchema={ - "type": "object", - "properties": { - "confirm": {"type": "boolean"}, - "reason": {"type": "string"}, - }, - "required": ["confirm"], - }, - ) - - if result.action == "accept" and result.content.get("confirm"): - # User confirmed - return CallToolResult(content=[TextContent(type="text", text="Files deleted")]) - else: - # User declined or cancelled - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) -``` - -### URL Elicitation - -Direct users to external URLs for OAuth, payments, or other out-of-band flows: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Waiting for OAuth...") - - result = await task.elicit_url( - message="Please authorize with GitHub", - url="https://github.com/login/oauth/authorize?client_id=...", - elicitation_id="oauth-github-123", - ) - - if result.action == "accept": - # User completed OAuth flow - return CallToolResult(content=[TextContent(type="text", text="Connected to GitHub")]) - else: - return CallToolResult(content=[TextContent(type="text", text="OAuth cancelled")]) -``` - -## Sampling Within Tasks - -Tasks can request LLM completions from the client: - -```python -from mcp.types import SamplingMessage, TextContent - -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Generating response...") - - result = await task.create_message( - messages=[ - SamplingMessage( - role="user", - content=TextContent(type="text", text="Write a haiku about coding"), - ) - ], - max_tokens=100, - ) - - haiku = result.content.text if isinstance(result.content, TextContent) else "Error" - return CallToolResult(content=[TextContent(type="text", text=haiku)]) -``` - -Sampling supports additional parameters: - -```python -result = await task.create_message( - messages=[...], - max_tokens=500, - system_prompt="You are a helpful assistant", - temperature=0.7, - stop_sequences=["\n\n"], - model_preferences=ModelPreferences(hints=[ModelHint(name="claude-3")]), -) -``` - -## Cancellation Support - -Check for cancellation in long-running work: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - for i in range(1000): - if task.is_cancelled: - # Clean up and exit - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - await task.update_status(f"Step {i}/1000") - await process_step(i) - - return CallToolResult(content=[TextContent(type="text", text="Complete")]) -``` - -The SDK's default cancel handler updates the task status. Your work function should check `is_cancelled` periodically. - -## Custom Task Store - -For production, implement `TaskStore` with persistent storage: - -```python -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import Task, TaskMetadata, Result - -class RedisTaskStore(TaskStore): - def __init__(self, redis_client): - self.redis = redis_client - - async def create_task(self, metadata: TaskMetadata, task_id: str | None = None) -> Task: - # Create and persist task - ... - - async def get_task(self, task_id: str) -> Task | None: - # Retrieve task from Redis - ... - - async def update_task(self, task_id: str, status: str | None = None, ...) -> Task: - # Update and persist - ... - - async def store_result(self, task_id: str, result: Result) -> None: - # Store result in Redis - ... - - async def get_result(self, task_id: str) -> Result | None: - # Retrieve result - ... - - # ... implement remaining methods -``` - -Use your custom store: - -```python -store = RedisTaskStore(redis_client) -server.experimental.enable_tasks(store=store) -``` - -## Complete Example - -A server with multiple task-supporting tools: - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import ( - CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, - SamplingMessage, TASK_REQUIRED, -) - -server = Server("task-demo") -server.experimental.enable_tasks() - - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="confirm_action", - description="Requires user confirmation", - inputSchema={"type": "object", "properties": {"action": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ), - Tool( - name="generate_text", - description="Generate text via LLM", - inputSchema={"type": "object", "properties": {"prompt": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ), - ] - - -async def handle_confirm_action(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - action = arguments.get("action", "unknown action") - - async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit( - message=f"Confirm: {action}?", - requestedSchema={ - "type": "object", - "properties": {"confirm": {"type": "boolean"}}, - "required": ["confirm"], - }, - ) - - if result.action == "accept" and result.content.get("confirm"): - return CallToolResult(content=[TextContent(type="text", text=f"Executed: {action}")]) - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - return await ctx.experimental.run_task(work) - - -async def handle_generate_text(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - prompt = arguments.get("prompt", "Hello") - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Generating...") - - result = await task.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], - max_tokens=200, - ) - - text = result.content.text if isinstance(result.content, TextContent) else "Error" - return CallToolResult(content=[TextContent(type="text", text=text)]) - - return await ctx.experimental.run_task(work) - - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "confirm_action": - return await handle_confirm_action(arguments) - elif name == "generate_text": - return await handle_generate_text(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) -``` - -## Error Handling in Tasks - -Tasks handle errors automatically, but you can also fail explicitly: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - try: - result = await risky_operation() - return CallToolResult(content=[TextContent(type="text", text=result)]) - except PermissionError: - await task.fail("Access denied - insufficient permissions") - raise - except TimeoutError: - await task.fail("Operation timed out after 30 seconds") - raise -``` - -When `run_task()` catches an exception, it automatically: - -1. Marks the task as `failed` -2. Sets `statusMessage` to the exception message -3. Propagates the exception (which is caught by the task group) - -For custom error messages, call `task.fail()` before raising. - -## HTTP Transport Example - -For web applications, use the Streamable HTTP transport: - -```python -import uvicorn - -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import ( - CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED, -) - - -server = Server("http-task-server") -server.experimental.enable_tasks() - - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="long_operation", - description="A long-running operation", - inputSchema={"type": "object", "properties": {"duration": {"type": "number"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ) - ] - - -async def handle_long_operation(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - duration = arguments.get("duration", 5) - - async def work(task: ServerTaskContext) -> CallToolResult: - import anyio - for i in range(int(duration)): - await task.update_status(f"Step {i+1}/{int(duration)}") - await anyio.sleep(1) - return CallToolResult(content=[TextContent(type="text", text=f"Completed after {duration}s")]) - - return await ctx.experimental.run_task(work) - - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "long_operation": - return await handle_long_operation(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) - - -if __name__ == "__main__": - uvicorn.run(server.streamable_http_app(), host="127.0.0.1", port=8000) -``` - -## Testing Task Servers - -Test task functionality with the SDK's testing utilities: - -```python -import pytest -import anyio -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - - -@pytest.mark.anyio -async def test_task_tool(): - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream(10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream(10) - - async def run_server(): - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client(): - async with ClientSession(server_to_client_receive, client_to_server_send) as session: - await session.initialize() - - # Call the tool as a task - result = await session.experimental.call_tool_as_task("my_tool", {"arg": "value"}) - task_id = result.task.taskId - assert result.task.status == "working" - - # Poll until complete - async for status in session.experimental.poll_task(task_id): - if status.status in ("completed", "failed"): - break - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - assert len(final.content) > 0 - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) -``` - -## Best Practices - -### Keep Work Functions Focused - -```python -# Good: focused work function -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Validating...") - validate_input(arguments) - - await task.update_status("Processing...") - result = await process_data(arguments) - - return CallToolResult(content=[TextContent(type="text", text=result)]) -``` - -### Check Cancellation in Loops - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - results = [] - for item in large_dataset: - if task.is_cancelled: - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - results.append(await process(item)) - - return CallToolResult(content=[TextContent(type="text", text=str(results))]) -``` - -### Use Meaningful Status Messages - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Connecting to database...") - db = await connect() - - await task.update_status("Fetching records (0/1000)...") - for i, record in enumerate(records): - if i % 100 == 0: - await task.update_status(f"Processing records ({i}/1000)...") - await process(record) - - await task.update_status("Finalizing results...") - return CallToolResult(content=[TextContent(type="text", text="Done")]) -``` - -### Handle Elicitation Responses - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit(message="Continue?", requestedSchema={...}) - - match result.action: - case "accept": - # User accepted, process content - return await process_accepted(result.content) - case "decline": - # User explicitly declined - return CallToolResult(content=[TextContent(type="text", text="User declined")]) - case "cancel": - # User cancelled the elicitation - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) -``` - -## Next Steps - -- [Client Usage](tasks-client.md) - Learn how clients interact with task servers -- [Tasks Overview](tasks.md) - Review lifecycle and concepts diff --git a/docs/experimental/tasks.md b/docs/experimental/tasks.md deleted file mode 100644 index 2d4d06a025..0000000000 --- a/docs/experimental/tasks.md +++ /dev/null @@ -1,188 +0,0 @@ -# Tasks - -!!! warning "Experimental" - - Tasks are an experimental feature tracking the draft MCP specification. - The API may change without notice. - -Tasks enable asynchronous request handling in MCP. Instead of blocking until an operation completes, the receiver creates a task, returns immediately, and the requestor polls for the result. - -## When to Use Tasks - -Tasks are designed for operations that: - -- Take significant time (seconds to minutes) -- Need progress updates during execution -- Require user input mid-execution (elicitation, sampling) -- Should run without blocking the requestor - -Common use cases: - -- Long-running data processing -- Multi-step workflows with user confirmation -- LLM-powered operations requiring sampling -- OAuth flows requiring user browser interaction - -## Task Lifecycle - -```text - ┌─────────────┐ - │ working │ - └──────┬──────┘ - │ - ┌────────────┼────────────┐ - │ │ │ - ▼ ▼ ▼ - ┌────────────┐ ┌───────────┐ ┌───────────┐ - │ completed │ │ failed │ │ cancelled │ - └────────────┘ └───────────┘ └───────────┘ - ▲ - │ - ┌────────┴────────┐ - │ input_required │◄──────┐ - └────────┬────────┘ │ - │ │ - └────────────────┘ -``` - -| Status | Description | -|--------|-------------| -| `working` | Task is being processed | -| `input_required` | Receiver needs input from requestor (elicitation/sampling) | -| `completed` | Task finished successfully | -| `failed` | Task encountered an error | -| `cancelled` | Task was cancelled by requestor | - -Terminal states (`completed`, `failed`, `cancelled`) are final—tasks cannot transition out of them. - -## Bidirectional Flow - -Tasks work in both directions: - -**Client → Server** (most common): - -```text -Client Server - │ │ - │── tools/call (task) ──────────────>│ Creates task - │<── CreateTaskResult ───────────────│ - │ │ - │── tasks/get ──────────────────────>│ - │<── status: working ────────────────│ - │ │ ... work continues ... - │── tasks/get ──────────────────────>│ - │<── status: completed ──────────────│ - │ │ - │── tasks/result ───────────────────>│ - │<── CallToolResult ─────────────────│ -``` - -**Server → Client** (for elicitation/sampling): - -```text -Server Client - │ │ - │── elicitation/create (task) ──────>│ Creates task - │<── CreateTaskResult ───────────────│ - │ │ - │── tasks/get ──────────────────────>│ - │<── status: working ────────────────│ - │ │ ... user interaction ... - │── tasks/get ──────────────────────>│ - │<── status: completed ──────────────│ - │ │ - │── tasks/result ───────────────────>│ - │<── ElicitResult ───────────────────│ -``` - -## Key Concepts - -### Task Metadata - -When augmenting a request with task execution, include `TaskMetadata`: - -```python -from mcp.types import TaskMetadata - -task = TaskMetadata(ttl=60000) # TTL in milliseconds -``` - -The `ttl` (time-to-live) specifies how long the task and result are retained after completion. - -### Task Store - -Servers persist task state in a `TaskStore`. The SDK provides `InMemoryTaskStore` for development: - -```python -from mcp.shared.experimental.tasks import InMemoryTaskStore - -store = InMemoryTaskStore() -``` - -For production, implement `TaskStore` with a database or distributed cache. - -### Capabilities - -Both servers and clients declare task support through capabilities: - -**Server capabilities:** - -- `tasks.requests.tools.call` - Server accepts task-augmented tool calls - -**Client capabilities:** - -- `tasks.requests.sampling.createMessage` - Client accepts task-augmented sampling -- `tasks.requests.elicitation.create` - Client accepts task-augmented elicitation - -The SDK manages these automatically when you enable task support. - -## Quick Example - -**Server** (simplified API): - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import CallToolResult, TextContent, TASK_REQUIRED - -server = Server("my-server") -server.experimental.enable_tasks() # One-line setup - -@server.call_tool() -async def handle_tool(name: str, arguments: dict): - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext): - await task.update_status("Processing...") - # ... do work ... - return CallToolResult(content=[TextContent(type="text", text="Done!")]) - - return await ctx.experimental.run_task(work) -``` - -**Client:** - -```python -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - -async with ClientSession(read, write) as session: - await session.initialize() - - # Call tool as task - result = await session.experimental.call_tool_as_task("my_tool", {"arg": "value"}) - task_id = result.task.taskId - - # Poll until done - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) -``` - -## Next Steps - -- [Server Implementation](tasks-server.md) - Build task-supporting servers -- [Client Usage](tasks-client.md) - Call and poll tasks from clients diff --git a/docs/migration.md b/docs/migration.md index 8b70885e8d..9850f74cd4 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -595,7 +595,7 @@ The `RequestContext` class has been split to separate shared fields from server- **`RequestContext` changes:** - Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` -- Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` +- Server-specific fields (`lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` **Before (v1):** @@ -861,7 +861,7 @@ server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handl **Key differences:** -- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session`, `lifespan_context`, and `experimental` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. +- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session` and `lifespan_context` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. - Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). - The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. @@ -872,7 +872,7 @@ All handlers receive `ctx: ServerRequestContext` as the first argument. The seco | v1 decorator | v2 constructor kwarg | `params` type | return type | |---|---|---|---| | `@server.list_tools()` | `on_list_tools` | `PaginatedRequestParams \| None` | `ListToolsResult` | -| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult \| CreateTaskResult` | +| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult` | | `@server.list_resources()` | `on_list_resources` | `PaginatedRequestParams \| None` | `ListResourcesResult` | | `@server.list_resource_templates()` | `on_list_resource_templates` | `PaginatedRequestParams \| None` | `ListResourceTemplatesResult` | | `@server.read_resource()` | `on_read_resource` | `ReadResourceRequestParams` | `ReadResourceResult` | @@ -1039,37 +1039,11 @@ from mcp.server import ServerRequestContext # but None in notification handlers ``` -### Experimental: task handler decorators removed +### Experimental Tasks support removed -The experimental decorator methods on `ExperimentalHandlers` (`@server.experimental.list_tasks()`, `@server.experimental.get_task()`, etc.) have been removed. +Tasks (SEP-1686) have been removed from the MCP specification and are no longer part of this SDK. The `mcp.client.experimental`, `mcp.server.experimental`, `mcp.shared.experimental`, and `mcp.server.lowlevel.experimental` modules have been removed, along with all `Task*` types, the `tasks` capability fields, `Tool.execution`, and the `experimental` properties on `ClientSession`, `ServerSession`, `Server`, and `ServerRequestContext`. -Default task handlers are still registered automatically via `server.experimental.enable_tasks()`. Custom handlers can be passed as `on_*` kwargs to override specific defaults. - -**Before (v1):** - -```python -server = Server("my-server") -server.experimental.enable_tasks() - -@server.experimental.get_task() -async def custom_get_task(request: GetTaskRequest) -> GetTaskResult: - ... -``` - -**After (v2):** - -```python -from mcp.server import Server, ServerRequestContext -from mcp.types import GetTaskRequestParams, GetTaskResult - - -async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - ... - - -server = Server("my-server") -server.experimental.enable_tasks(on_get_task=custom_get_task) -``` +Tasks are expected to return as a separate MCP extension in a future release. ## Deprecations diff --git a/examples/clients/simple-task-client/README.md b/examples/clients/simple-task-client/README.md deleted file mode 100644 index 103be0f1fb..0000000000 --- a/examples/clients/simple-task-client/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Simple Task Client - -A minimal MCP client demonstrating polling for task results over streamable HTTP. - -## Running - -First, start the simple-task server in another terminal: - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -Then run the client: - -```bash -cd examples/clients/simple-task-client -uv run mcp-simple-task-client -``` - -Use `--url` to connect to a different server. - -## What it does - -1. Connects to the server via streamable HTTP -2. Calls the `long_running_task` tool as a task -3. Polls the task status until completion -4. Retrieves and prints the result - -## Expected output - -```text -Available tools: ['long_running_task'] - -Calling tool as a task... -Task created: - Status: working - Starting work... - Status: working - Processing step 1... - Status: working - Processing step 2... - Status: completed - - -Result: Task completed! -``` diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py deleted file mode 100644 index 2fc2cda8d9..0000000000 --- a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .main import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/main.py b/examples/clients/simple-task-client/mcp_simple_task_client/main.py deleted file mode 100644 index f9e555c8e6..0000000000 --- a/examples/clients/simple-task-client/mcp_simple_task_client/main.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Simple task client demonstrating MCP tasks polling over streamable HTTP.""" - -import asyncio - -import click -from mcp import ClientSession -from mcp.client.streamable_http import streamable_http_client -from mcp.types import CallToolResult, TextContent - - -async def run(url: str) -> None: - async with streamable_http_client(url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - - # List tools - tools = await session.list_tools() - print(f"Available tools: {[t.name for t in tools.tools]}") - - # Call the tool as a task - print("\nCalling tool as a task...") - - result = await session.experimental.call_tool_as_task( - "long_running_task", - arguments={}, - ttl=60000, - ) - task_id = result.task.task_id - print(f"Task created: {task_id}") - - status = None - # Poll until done (respects server's pollInterval hint) - async for status in session.experimental.poll_task(task_id): - print(f" Status: {status.status} - {status.status_message or ''}") - - # Check final status - if status and status.status != "completed": - print(f"Task ended with status: {status.status}") - return - - # Get the result - task_result = await session.experimental.get_task_result(task_id, CallToolResult) - content = task_result.content[0] - if isinstance(content, TextContent): - print(f"\nResult: {content.text}") - - -@click.command() -@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") -def main(url: str) -> int: - asyncio.run(run(url)) - return 0 - - -if __name__ == "__main__": - main() diff --git a/examples/clients/simple-task-client/pyproject.toml b/examples/clients/simple-task-client/pyproject.toml deleted file mode 100644 index c7abf51159..0000000000 --- a/examples/clients/simple-task-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-client" -version = "0.1.0" -description = "A simple MCP client demonstrating task polling" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "client"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["click>=8.0", "mcp"] - -[project.scripts] -mcp-simple-task-client = "mcp_simple_task_client.main:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_client"] - -[tool.pyright] -include = ["mcp_simple_task_client"] -venvPath = "." -venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/clients/simple-task-interactive-client/README.md b/examples/clients/simple-task-interactive-client/README.md deleted file mode 100644 index 3397d3b5d7..0000000000 --- a/examples/clients/simple-task-interactive-client/README.md +++ /dev/null @@ -1,87 +0,0 @@ -# Simple Interactive Task Client - -A minimal MCP client demonstrating responses to interactive tasks (elicitation and sampling). - -## Running - -First, start the interactive task server in another terminal: - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -Then run the client: - -```bash -cd examples/clients/simple-task-interactive-client -uv run mcp-simple-task-interactive-client -``` - -Use `--url` to connect to a different server. - -## What it does - -1. Connects to the server via streamable HTTP -2. Calls `confirm_delete` - server asks for confirmation, client responds via terminal -3. Calls `write_haiku` - server requests LLM completion, client returns a hardcoded haiku - -## Key concepts - -### Elicitation callback - -```python -async def elicitation_callback(context, params) -> ElicitResult: - # Handle user input request from server - return ElicitResult(action="accept", content={"confirm": True}) -``` - -### Sampling callback - -```python -async def sampling_callback(context, params) -> CreateMessageResult: - # Handle LLM completion request from server - return CreateMessageResult(model="...", role="assistant", content=...) -``` - -### Using call_tool_as_task - -```python -# Call a tool as a task (returns immediately with task reference) -result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) -task_id = result.task.task_id - -# Get result - this delivers elicitation/sampling requests and blocks until complete -final = await session.experimental.get_task_result(task_id, CallToolResult) -``` - -**Important**: The `get_task_result()` call is what triggers the delivery of elicitation -and sampling requests to your callbacks. It blocks until the task completes and returns -the final result. - -## Expected output - -```text -Available tools: ['confirm_delete', 'write_haiku'] - ---- Demo 1: Elicitation --- -Calling confirm_delete tool... -Task created: - -[Elicitation] Server asks: Are you sure you want to delete 'important.txt'? -Your response (y/n): y -[Elicitation] Responding with: confirm=True -Result: Deleted 'important.txt' - ---- Demo 2: Sampling --- -Calling write_haiku tool... -Task created: - -[Sampling] Server requests LLM completion for: Write a haiku about autumn leaves -[Sampling] Responding with haiku -Result: -Haiku: -Cherry blossoms fall -Softly on the quiet pond -Spring whispers goodbye -``` diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py deleted file mode 100644 index 2fc2cda8d9..0000000000 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .main import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py deleted file mode 100644 index ff5f499280..0000000000 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Simple interactive task client demonstrating elicitation and sampling responses. - -This example demonstrates the spec-compliant polling pattern: -1. Poll tasks/get watching for status changes -2. On input_required, call tasks/result to receive elicitation/sampling requests -3. Continue until terminal status, then retrieve final result -""" - -import asyncio - -import click -from mcp import ClientSession -from mcp.client.context import ClientRequestContext -from mcp.client.streamable_http import streamable_http_client -from mcp.types import ( - CallToolResult, - CreateMessageRequestParams, - CreateMessageResult, - ElicitRequestParams, - ElicitResult, - TextContent, -) - - -async def elicitation_callback( - context: ClientRequestContext, - params: ElicitRequestParams, -) -> ElicitResult: - """Handle elicitation requests from the server.""" - print(f"\n[Elicitation] Server asks: {params.message}") - - # Simple terminal prompt - response = input("Your response (y/n): ").strip().lower() - confirmed = response in ("y", "yes", "true", "1") - - print(f"[Elicitation] Responding with: confirm={confirmed}") - return ElicitResult(action="accept", content={"confirm": confirmed}) - - -async def sampling_callback( - context: ClientRequestContext, - params: CreateMessageRequestParams, -) -> CreateMessageResult: - """Handle sampling requests from the server.""" - # Get the prompt from the first message - prompt = "unknown" - if params.messages: - content = params.messages[0].content - if isinstance(content, TextContent): - prompt = content.text - - print(f"\n[Sampling] Server requests LLM completion for: {prompt}") - - # Return a hardcoded haiku (in real use, call your LLM here) - haiku = """Cherry blossoms fall -Softly on the quiet pond -Spring whispers goodbye""" - - print("[Sampling] Responding with haiku") - return CreateMessageResult( - model="mock-haiku-model", - role="assistant", - content=TextContent(type="text", text=haiku), - ) - - -def get_text(result: CallToolResult) -> str: - """Extract text from a CallToolResult.""" - if result.content and isinstance(result.content[0], TextContent): - return result.content[0].text - return "(no text)" - - -async def run(url: str) -> None: - async with streamable_http_client(url) as (read, write): - async with ClientSession( - read, - write, - elicitation_callback=elicitation_callback, - sampling_callback=sampling_callback, - ) as session: - await session.initialize() - - # List tools - tools = await session.list_tools() - print(f"Available tools: {[t.name for t in tools.tools]}") - - # Demo 1: Elicitation (confirm_delete) - print("\n--- Demo 1: Elicitation ---") - print("Calling confirm_delete tool...") - - elicit_task = await session.experimental.call_tool_as_task("confirm_delete", {"filename": "important.txt"}) - elicit_task_id = elicit_task.task.task_id - print(f"Task created: {elicit_task_id}") - - # Poll until terminal, calling tasks/result on input_required - async for status in session.experimental.poll_task(elicit_task_id): - print(f"[Poll] Status: {status.status}") - if status.status == "input_required": - # Server needs input - tasks/result delivers the elicitation request - elicit_result = await session.experimental.get_task_result(elicit_task_id, CallToolResult) - break - else: - # poll_task exited due to terminal status - elicit_result = await session.experimental.get_task_result(elicit_task_id, CallToolResult) - - print(f"Result: {get_text(elicit_result)}") - - # Demo 2: Sampling (write_haiku) - print("\n--- Demo 2: Sampling ---") - print("Calling write_haiku tool...") - - sampling_task = await session.experimental.call_tool_as_task("write_haiku", {"topic": "autumn leaves"}) - sampling_task_id = sampling_task.task.task_id - print(f"Task created: {sampling_task_id}") - - # Poll until terminal, calling tasks/result on input_required - async for status in session.experimental.poll_task(sampling_task_id): - print(f"[Poll] Status: {status.status}") - if status.status == "input_required": - sampling_result = await session.experimental.get_task_result(sampling_task_id, CallToolResult) - break - else: - sampling_result = await session.experimental.get_task_result(sampling_task_id, CallToolResult) - - print(f"Result:\n{get_text(sampling_result)}") - - -@click.command() -@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") -def main(url: str) -> int: - asyncio.run(run(url)) - return 0 - - -if __name__ == "__main__": - main() diff --git a/examples/clients/simple-task-interactive-client/pyproject.toml b/examples/clients/simple-task-interactive-client/pyproject.toml deleted file mode 100644 index 47191573f2..0000000000 --- a/examples/clients/simple-task-interactive-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-interactive-client" -version = "0.1.0" -description = "A simple MCP client demonstrating interactive task responses" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "client", "elicitation", "sampling"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["click>=8.0", "mcp"] - -[project.scripts] -mcp-simple-task-interactive-client = "mcp_simple_task_interactive_client.main:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_interactive_client"] - -[tool.pyright] -include = ["mcp_simple_task_interactive_client"] -venvPath = "." -venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task-interactive/README.md b/examples/servers/simple-task-interactive/README.md deleted file mode 100644 index b8f384cb48..0000000000 --- a/examples/servers/simple-task-interactive/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# Simple Interactive Task Server - -A minimal MCP server demonstrating interactive tasks with elicitation and sampling. - -## Running - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. - -## What it does - -This server exposes two tools: - -### `confirm_delete` (demonstrates elicitation) - -Asks the user for confirmation before "deleting" a file. - -- Uses `task.elicit()` to request user input -- Shows the elicitation flow: task -> input_required -> response -> complete - -### `write_haiku` (demonstrates sampling) - -Asks the LLM to write a haiku about a topic. - -- Uses `task.create_message()` to request LLM completion -- Shows the sampling flow: task -> input_required -> response -> complete - -## Usage with the client - -In one terminal, start the server: - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -In another terminal, run the interactive client: - -```bash -cd examples/clients/simple-task-interactive-client -uv run mcp-simple-task-interactive-client -``` - -## Expected server output - -When a client connects and calls the tools, you'll see: - -```text -Starting server on http://localhost:8000/mcp - -[Server] confirm_delete called for 'important.txt' -[Server] Task created: -[Server] Sending elicitation request to client... -[Server] Received elicitation response: action=accept, content={'confirm': True} -[Server] Completing task with result: Deleted 'important.txt' - -[Server] write_haiku called for topic 'autumn leaves' -[Server] Task created: -[Server] Sending sampling request to client... -[Server] Received sampling response: Cherry blossoms fall -Softly on the quiet pon... -[Server] Completing task with haiku -``` - -## Key concepts - -1. **ServerTaskContext**: Provides `elicit()` and `create_message()` for user interaction -2. **run_task()**: Spawns background work, auto-completes/fails, returns immediately -3. **TaskResultHandler**: Delivers queued messages and routes responses -4. **Response routing**: Responses are routed back to waiting resolvers diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py deleted file mode 100644 index e7ef16530b..0000000000 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .server import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py deleted file mode 100644 index bc06e12088..0000000000 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Simple interactive task server demonstrating elicitation and sampling. - -This example shows the simplified task API where: -- server.experimental.enable_tasks() sets up all infrastructure -- ctx.experimental.run_task() handles task lifecycle automatically -- ServerTaskContext.elicit() and ServerTaskContext.create_message() queue requests properly -""" - -from typing import Any - -import click -import uvicorn -from mcp import types -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext - - -async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None -) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="confirm_delete", - description="Asks for confirmation before deleting (demonstrates elicitation)", - input_schema={ - "type": "object", - "properties": {"filename": {"type": "string"}}, - }, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - types.Tool( - name="write_haiku", - description="Asks LLM to write a haiku (demonstrates sampling)", - input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - ] - ) - - -async def handle_confirm_delete(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the confirm_delete tool - demonstrates elicitation.""" - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - filename = arguments.get("filename", "unknown.txt") - print(f"\n[Server] confirm_delete called for '{filename}'") - - async def work(task: ServerTaskContext) -> types.CallToolResult: - print(f"[Server] Task {task.task_id} starting elicitation...") - - result = await task.elicit( - message=f"Are you sure you want to delete '{filename}'?", - requested_schema={ - "type": "object", - "properties": {"confirm": {"type": "boolean"}}, - "required": ["confirm"], - }, - ) - - print(f"[Server] Received elicitation response: action={result.action}, content={result.content}") - - if result.action == "accept" and result.content: - confirmed = result.content.get("confirm", False) - text = f"Deleted '{filename}'" if confirmed else "Deletion cancelled" - else: - text = "Deletion cancelled" - - print(f"[Server] Completing task with result: {text}") - return types.CallToolResult(content=[types.TextContent(type="text", text=text)]) - - return await ctx.experimental.run_task(work) - - -async def handle_write_haiku(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the write_haiku tool - demonstrates sampling.""" - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - topic = arguments.get("topic", "nature") - print(f"\n[Server] write_haiku called for topic '{topic}'") - - async def work(task: ServerTaskContext) -> types.CallToolResult: - print(f"[Server] Task {task.task_id} starting sampling...") - - result = await task.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text=f"Write a haiku about {topic}"), - ) - ], - max_tokens=50, - ) - - haiku = "No response" - if isinstance(result.content, types.TextContent): - haiku = result.content.text - - print(f"[Server] Received sampling response: {haiku[:50]}...") - return types.CallToolResult(content=[types.TextContent(type="text", text=f"Haiku:\n{haiku}")]) - - return await ctx.experimental.run_task(work) - - -async def handle_call_tool( - ctx: ServerRequestContext, params: types.CallToolRequestParams -) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - arguments = params.arguments or {} - - if params.name == "confirm_delete": - return await handle_confirm_delete(ctx, arguments) - elif params.name == "write_haiku": - return await handle_write_haiku(ctx, arguments) - - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], - is_error=True, - ) - - -server = Server( - "simple-task-interactive", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, -) - -# Enable task support - this auto-registers all handlers -server.experimental.enable_tasks() - - -@click.command() -@click.option("--port", default=8000, help="Port to listen on") -def main(port: int) -> int: - starlette_app = server.streamable_http_app() - print(f"Starting server on http://localhost:{port}/mcp") - uvicorn.run(starlette_app, host="127.0.0.1", port=port) - return 0 diff --git a/examples/servers/simple-task-interactive/pyproject.toml b/examples/servers/simple-task-interactive/pyproject.toml deleted file mode 100644 index 4ec9770763..0000000000 --- a/examples/servers/simple-task-interactive/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-interactive" -version = "0.1.0" -description = "A simple MCP server demonstrating interactive tasks (elicitation & sampling)" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "elicitation", "sampling"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] - -[project.scripts] -mcp-simple-task-interactive = "mcp_simple_task_interactive.server:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_interactive"] - -[tool.pyright] -include = ["mcp_simple_task_interactive"] -venvPath = "." -venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task/README.md b/examples/servers/simple-task/README.md deleted file mode 100644 index 6914e0414f..0000000000 --- a/examples/servers/simple-task/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# Simple Task Server - -A minimal MCP server demonstrating the experimental tasks feature over streamable HTTP. - -## Running - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. - -## What it does - -This server exposes a single tool `long_running_task` that: - -1. Must be called as a task (with `task` metadata in the request) -2. Takes ~3 seconds to complete -3. Sends status updates during execution -4. Returns a result when complete - -## Usage with the client - -In one terminal, start the server: - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -In another terminal, run the client: - -```bash -cd examples/clients/simple-task-client -uv run mcp-simple-task-client -``` diff --git a/examples/servers/simple-task/mcp_simple_task/__init__.py b/examples/servers/simple-task/mcp_simple_task/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/servers/simple-task/mcp_simple_task/__main__.py b/examples/servers/simple-task/mcp_simple_task/__main__.py deleted file mode 100644 index e7ef16530b..0000000000 --- a/examples/servers/simple-task/mcp_simple_task/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .server import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py deleted file mode 100644 index 7583cd8f0e..0000000000 --- a/examples/servers/simple-task/mcp_simple_task/server.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Simple task server demonstrating MCP tasks over streamable HTTP.""" - -import anyio -import click -import uvicorn -from mcp import types -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext - - -async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None -) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="long_running_task", - description="A task that takes a few seconds to complete with status updates", - input_schema={"type": "object", "properties": {}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ) - ] - ) - - -async def handle_call_tool( - ctx: ServerRequestContext, params: types.CallToolRequestParams -) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - if params.name == "long_running_task": - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> types.CallToolResult: - await task.update_status("Starting work...") - await anyio.sleep(1) - - await task.update_status("Processing step 1...") - await anyio.sleep(1) - - await task.update_status("Processing step 2...") - await anyio.sleep(1) - - return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) - - return await ctx.experimental.run_task(work) - - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], - is_error=True, - ) - - -server = Server( - "simple-task-server", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, -) - -# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task -server.experimental.enable_tasks() - - -@click.command() -@click.option("--port", default=8000, help="Port to listen on") -def main(port: int) -> int: - starlette_app = server.streamable_http_app() - - print(f"Starting server on http://localhost:{port}/mcp") - uvicorn.run(starlette_app, host="127.0.0.1", port=port) - return 0 diff --git a/examples/servers/simple-task/pyproject.toml b/examples/servers/simple-task/pyproject.toml deleted file mode 100644 index 921a1c34fc..0000000000 --- a/examples/servers/simple-task/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task" -version = "0.1.0" -description = "A simple MCP server demonstrating tasks" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] - -[project.scripts] -mcp-simple-task = "mcp_simple_task.server:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task"] - -[tool.pyright] -include = ["mcp_simple_task"] -venvPath = "." -venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/mkdocs.yml b/mkdocs.yml index e48c64242d..cb89faf0f0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -19,12 +19,6 @@ nav: - Low-Level Server: low-level-server.md - Authorization: authorization.md - Testing: testing.md - - Experimental: - - Overview: experimental/index.md - - Tasks: - - Introduction: experimental/tasks.md - - Server Implementation: experimental/tasks-server.md - - Client Usage: experimental/tasks-client.md - API Reference: api/ theme: diff --git a/src/mcp/client/experimental/__init__.py b/src/mcp/client/experimental/__init__.py deleted file mode 100644 index 8d74cb3044..0000000000 --- a/src/mcp/client/experimental/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Experimental client features. - -WARNING: These APIs are experimental and may change without notice. -""" - -from mcp.client.experimental.tasks import ExperimentalClientFeatures - -__all__ = ["ExperimentalClientFeatures"] diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py deleted file mode 100644 index 0ab513236a..0000000000 --- a/src/mcp/client/experimental/task_handlers.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Experimental task handler protocols for server -> client requests. - -This module provides Protocol types and default handlers for when servers -send task-related requests to clients (the reverse of normal client -> server flow). - -WARNING: These APIs are experimental and may change without notice. - -Use cases: -- Server sends task-augmented sampling/elicitation request to client -- Client creates a local task, spawns background work, returns CreateTaskResult -- Server polls client's task status via tasks/get, tasks/result, etc. -""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Protocol - -from pydantic import TypeAdapter - -from mcp import types -from mcp.shared._context import RequestContext -from mcp.shared.session import RequestResponder - -if TYPE_CHECKING: - from mcp.client.session import ClientSession - - -class GetTaskHandlerFnT(Protocol): - """Handler for tasks/get requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.GetTaskRequestParams, - ) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch - - -class GetTaskResultHandlerFnT(Protocol): - """Handler for tasks/result requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.GetTaskPayloadRequestParams, - ) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch - - -class ListTasksHandlerFnT(Protocol): - """Handler for tasks/list requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch - - -class CancelTaskHandlerFnT(Protocol): - """Handler for tasks/cancel requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch - - -class TaskAugmentedSamplingFnT(Protocol): - """Handler for task-augmented sampling/createMessage requests from server. - - When server sends a CreateMessageRequest with task field, this callback - is invoked. The callback should create a task, spawn background work, - and return CreateTaskResult immediately. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch - - -class TaskAugmentedElicitationFnT(Protocol): - """Handler for task-augmented elicitation/create requests from server. - - When server sends an ElicitRequest with task field, this callback - is invoked. The callback should create a task, spawn background work, - and return CreateTaskResult immediately. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.ElicitRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch - - -async def default_get_task_handler( - context: RequestContext[ClientSession], - params: types.GetTaskRequestParams, -) -> types.GetTaskResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/get not supported", - ) - - -async def default_get_task_result_handler( - context: RequestContext[ClientSession], - params: types.GetTaskPayloadRequestParams, -) -> types.GetTaskPayloadResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/result not supported", - ) - - -async def default_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, -) -> types.ListTasksResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/list not supported", - ) - - -async def default_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, -) -> types.CancelTaskResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/cancel not supported", - ) - - -async def default_task_augmented_sampling( - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, -) -> types.CreateTaskResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="Task-augmented sampling not supported", - ) - - -async def default_task_augmented_elicitation( - context: RequestContext[ClientSession], - params: types.ElicitRequestParams, - task_metadata: types.TaskMetadata, -) -> types.CreateTaskResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="Task-augmented elicitation not supported", - ) - - -@dataclass -class ExperimentalTaskHandlers: - """Container for experimental task handlers. - - Groups all task-related handlers that handle server -> client requests. - This includes both pure task requests (get, list, cancel, result) and - task-augmented request handlers (sampling, elicitation with task field). - - WARNING: These APIs are experimental and may change without notice. - - Example: - ```python - handlers = ExperimentalTaskHandlers( - get_task=my_get_task_handler, - list_tasks=my_list_tasks_handler, - ) - session = ClientSession(..., experimental_task_handlers=handlers) - ``` - """ - - # Pure task request handlers - get_task: GetTaskHandlerFnT = field(default=default_get_task_handler) - get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler) - list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler) - cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler) - - # Task-augmented request handlers - augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling) - augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation) - - def build_capability(self) -> types.ClientTasksCapability | None: - """Build ClientTasksCapability from the configured handlers. - - Returns a capability object that reflects which handlers are configured - (i.e., not using the default "not supported" handlers). - - Returns: - ClientTasksCapability if any handlers are provided, None otherwise - """ - has_list = self.list_tasks is not default_list_tasks_handler - has_cancel = self.cancel_task is not default_cancel_task_handler - has_sampling = self.augmented_sampling is not default_task_augmented_sampling - has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation - - # If no handlers are provided, return None - if not any([has_list, has_cancel, has_sampling, has_elicitation]): - return None - - # Build requests capability if any request handlers are provided - requests_capability: types.ClientTasksRequestsCapability | None = None - if has_sampling or has_elicitation: - requests_capability = types.ClientTasksRequestsCapability( - sampling=types.TasksSamplingCapability(create_message=types.TasksCreateMessageCapability()) - if has_sampling - else None, - elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability()) - if has_elicitation - else None, - ) - - return types.ClientTasksCapability( - list=types.TasksListCapability() if has_list else None, - cancel=types.TasksCancelCapability() if has_cancel else None, - requests=requests_capability, - ) - - @staticmethod - def handles_request(request: types.ServerRequest) -> bool: - """Check if this handler handles the given request type.""" - return isinstance( - request, - types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest, - ) - - async def handle_request( - self, - ctx: RequestContext[ClientSession], - responder: RequestResponder[types.ServerRequest, types.ClientResult], - ) -> None: - """Handle a task-related request from the server. - - Call handles_request() first to check if this handler can handle the request. - """ - client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( - types.ClientResult | types.ErrorData - ) - - match responder.request: - case types.GetTaskRequest(params=params): - response = await self.get_task(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.GetTaskPayloadRequest(params=params): - response = await self.get_task_result(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.ListTasksRequest(params=params): - response = await self.list_tasks(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.CancelTaskRequest(params=params): - response = await self.cancel_task(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case _: # pragma: no cover - raise ValueError(f"Unhandled request type: {type(responder.request)}") - - -# Backwards compatibility aliases -default_task_augmented_sampling_callback = default_task_augmented_sampling -default_task_augmented_elicitation_callback = default_task_augmented_elicitation diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py deleted file mode 100644 index a566df766b..0000000000 --- a/src/mcp/client/experimental/tasks.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Experimental client-side task support. - -This module provides client methods for interacting with MCP tasks. - -WARNING: These APIs are experimental and may change without notice. - -Example: - ```python - # Call a tool as a task - result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) - task_id = result.task.task_id - - # Get task status - status = await session.experimental.get_task(task_id) - - # Get task result when complete - if status.status == "completed": - result = await session.experimental.get_task_result(task_id, CallToolResult) - - # List all tasks - tasks = await session.experimental.list_tasks() - - # Cancel a task - await session.experimental.cancel_task(task_id) - ``` -""" - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, TypeVar - -from mcp import types -from mcp.shared.experimental.tasks.polling import poll_until_terminal -from mcp.types._types import RequestParamsMeta - -if TYPE_CHECKING: - from mcp.client.session import ClientSession - -ResultT = TypeVar("ResultT", bound=types.Result) - - -class ExperimentalClientFeatures: - """Experimental client features for tasks and other experimental APIs. - - WARNING: These APIs are experimental and may change without notice. - - Access via session.experimental: - status = await session.experimental.get_task(task_id) - """ - - def __init__(self, session: "ClientSession") -> None: - self._session = session - - async def call_tool_as_task( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - ttl: int = 60000, - meta: RequestParamsMeta | None = None, - ) -> types.CreateTaskResult: - """Call a tool as a task, returning a CreateTaskResult for polling. - - This is a convenience method for calling tools that support task execution. - The server will return a task reference instead of the immediate result, - which can then be polled via `get_task()` and retrieved via `get_task_result()`. - - Args: - name: The tool name - arguments: Tool arguments - ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute) - meta: Optional metadata to include in the request - - Returns: - CreateTaskResult containing the task reference - - Example: - ```python - # Create task - result = await session.experimental.call_tool_as_task( - "long_running_tool", {"input": "data"} - ) - task_id = result.task.task_id - - # Poll for completion - while True: - status = await session.experimental.get_task(task_id) - if status.status == "completed": - break - await anyio.sleep(0.5) - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - return await self._session.send_request( - types.CallToolRequest( - params=types.CallToolRequestParams( - name=name, - arguments=arguments, - task=types.TaskMetadata(ttl=ttl), - _meta=meta, - ), - ), - types.CreateTaskResult, - ) - - async def get_task(self, task_id: str) -> types.GetTaskResult: - """Get the current status of a task. - - Args: - task_id: The task identifier - - Returns: - GetTaskResult containing the task status and metadata - """ - return await self._session.send_request( - types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)), - types.GetTaskResult, - ) - - async def get_task_result( - self, - task_id: str, - result_type: type[ResultT], - ) -> ResultT: - """Get the result of a completed task. - - The result type depends on the original request type: - - tools/call tasks return CallToolResult - - Other request types return their corresponding result type - - Args: - task_id: The task identifier - result_type: The expected result type (e.g., CallToolResult) - - Returns: - The task result, validated against result_type - """ - return await self._session.send_request( - types.GetTaskPayloadRequest( - params=types.GetTaskPayloadRequestParams(task_id=task_id), - ), - result_type, - ) - - async def list_tasks( - self, - cursor: str | None = None, - ) -> types.ListTasksResult: - """List all tasks. - - Args: - cursor: Optional pagination cursor - - Returns: - ListTasksResult containing tasks and optional next cursor - """ - params = types.PaginatedRequestParams(cursor=cursor) if cursor else None - return await self._session.send_request( - types.ListTasksRequest(params=params), - types.ListTasksResult, - ) - - async def cancel_task(self, task_id: str) -> types.CancelTaskResult: - """Cancel a running task. - - Args: - task_id: The task identifier - - Returns: - CancelTaskResult with the updated task state - """ - return await self._session.send_request( - types.CancelTaskRequest( - params=types.CancelTaskRequestParams(task_id=task_id), - ), - types.CancelTaskResult, - ) - - async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: - """Poll a task until it reaches a terminal status. - - Yields GetTaskResult for each poll, allowing the caller to react to - status changes (e.g., handle input_required). Exits when the task reaches - a terminal status (completed, failed, cancelled). - - Respects the pollInterval hint from the server. - - Args: - task_id: The task identifier - - Yields: - GetTaskResult for each poll - - Example: - ```python - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - if status.status == "input_required": - # Handle elicitation request via tasks/result - pass - - # Task is now terminal, get the result - result = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - async for status in poll_until_terminal(self.get_task, task_id): - yield status diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 86113874be..08f532eca5 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,8 +8,6 @@ from mcp import types from mcp.client._transport import ReadStream, WriteStream -from mcp.client.experimental import ExperimentalClientFeatures -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -120,7 +118,6 @@ def __init__( client_info: types.Implementation | None = None, *, sampling_capabilities: types.SamplingCapability | None = None, - experimental_task_handlers: ExperimentalTaskHandlers | None = None, ) -> None: super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) self._client_info = client_info or DEFAULT_CLIENT_INFO @@ -132,10 +129,6 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None - self._experimental_features: ExperimentalClientFeatures | None = None - - # Experimental: Task handlers (use defaults if not provided) - self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() @property def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: @@ -174,7 +167,6 @@ async def initialize(self) -> types.InitializeResult: elicitation=elicitation, experimental=None, roots=roots, - tasks=self._task_handlers.build_capability(), ), client_info=self._client_info, ), @@ -199,23 +191,6 @@ def initialize_result(self) -> types.InitializeResult | None: """ return self._initialize_result - @property - def experimental(self) -> ExperimentalClientFeatures: - """Experimental APIs for tasks and other features. - - !!! warning - These APIs are experimental and may change without notice. - - Example: - ```python - status = await session.experimental.get_task(task_id) - result = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - if self._experimental_features is None: - self._experimental_features = ExperimentalClientFeatures(self) - return self._experimental_features - async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: """Send a ping request.""" return await self.send_request(types.PingRequest(params=types.RequestParams(_meta=meta)), types.EmptyResult) @@ -413,31 +388,16 @@ async def send_roots_list_changed(self) -> None: async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self) - # Delegate to experimental task handler if applicable - if self._task_handlers.handles_request(responder.request): - with responder: - await self._task_handlers.handle_request(ctx, responder) - return None - - # Core request handling match responder.request: case types.CreateMessageRequest(params=params): with responder: - # Check if this is a task-augmented request - if params.task is not None: - response = await self._task_handlers.augmented_sampling(ctx, params, params.task) - else: - response = await self._sampling_callback(ctx, params) + response = await self._sampling_callback(ctx, params) client_response = ClientResponse.validate_python(response) await responder.respond(client_response) case types.ElicitRequest(params=params): with responder: - # Check if this is a task-augmented request - if params.task is not None: - response = await self._task_handlers.augmented_elicitation(ctx, params, params.task) - else: - response = await self._elicitation_callback(ctx, params) + response = await self._elicitation_callback(ctx, params) client_response = ClientResponse.validate_python(response) await responder.respond(client_response) @@ -447,14 +407,9 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): + case types.PingRequest(): # pragma: no branch with responder: - return await responder.respond(types.EmptyResult()) - - case _: # pragma: no cover - pass # Task requests handled above by _task_handlers - - return None + await responder.respond(types.EmptyResult()) async def _handle_incoming( self, diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d8e11d78b2..bc54c5d2eb 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -5,7 +5,6 @@ from typing_extensions import TypeVar -from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.shared.message import CloseSSEStreamCallback @@ -17,7 +16,6 @@ @dataclass(kw_only=True) class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContextT, RequestT]): lifespan_context: LifespanContextT - experimental: Experimental request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/server/experimental/__init__.py b/src/mcp/server/experimental/__init__.py deleted file mode 100644 index fd1db623f2..0000000000 --- a/src/mcp/server/experimental/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Server-side experimental features. - -WARNING: These APIs are experimental and may change without notice. - -Import directly from submodules: -- mcp.server.experimental.task_context.ServerTaskContext -- mcp.server.experimental.task_support.TaskSupport -- mcp.server.experimental.task_result_handler.TaskResultHandler -- mcp.server.experimental.request_context.Experimental -""" diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py deleted file mode 100644 index 3eba65822a..0000000000 --- a/src/mcp/server/experimental/request_context.py +++ /dev/null @@ -1,217 +0,0 @@ -"""Experimental request context features. - -This module provides the Experimental class which gives access to experimental -features within a request context, such as task-augmented request handling. - -WARNING: These APIs are experimental and may change without notice. -""" - -from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field -from typing import Any - -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY, is_terminal -from mcp.types import ( - METHOD_NOT_FOUND, - TASK_FORBIDDEN, - TASK_REQUIRED, - ClientCapabilities, - CreateTaskResult, - ErrorData, - Result, - TaskExecutionMode, - TaskMetadata, - Tool, -) - - -@dataclass -class Experimental: - """Experimental features context for task-augmented requests. - - Provides helpers for validating task execution compatibility and - running tasks with automatic lifecycle management. - - WARNING: This API is experimental and may change without notice. - """ - - task_metadata: TaskMetadata | None = None - _client_capabilities: ClientCapabilities | None = field(default=None, repr=False) - _session: ServerSession | None = field(default=None, repr=False) - _task_support: TaskSupport | None = field(default=None, repr=False) - - @property - def is_task(self) -> bool: - """Check if this request is task-augmented.""" - return self.task_metadata is not None - - @property - def client_supports_tasks(self) -> bool: - """Check if the client declared task support.""" - if self._client_capabilities is None: - return False - return self._client_capabilities.tasks is not None - - def validate_task_mode( - self, tool_task_mode: TaskExecutionMode | None, *, raise_error: bool = True - ) -> ErrorData | None: - """Validate that the request is compatible with the tool's task execution mode. - - Per MCP spec: - - "required": Clients MUST invoke as a task. Server returns -32601 if not. - - "forbidden" (or None): Clients MUST NOT invoke as a task. Server returns -32601 if they do. - - "optional": Either is acceptable. - - Args: - tool_task_mode: The tool's execution.taskSupport value - ("forbidden", "optional", "required", or None) - raise_error: If True, raises MCPError on validation failure. If False, returns ErrorData. - - Returns: - None if valid, ErrorData if invalid and raise_error=False - - Raises: - MCPError: If invalid and raise_error=True - """ - - mode = tool_task_mode or TASK_FORBIDDEN - - error: ErrorData | None = None - - if mode == TASK_REQUIRED and not self.is_task: - error = ErrorData(code=METHOD_NOT_FOUND, message="This tool requires task-augmented invocation") - elif mode == TASK_FORBIDDEN and self.is_task: - error = ErrorData(code=METHOD_NOT_FOUND, message="This tool does not support task-augmented invocation") - - if error is not None and raise_error: - raise MCPError.from_error_data(error) - - return error - - def validate_for_tool(self, tool: Tool, *, raise_error: bool = True) -> ErrorData | None: - """Validate that the request is compatible with the given tool. - - Convenience wrapper around validate_task_mode that extracts the mode from a Tool. - - Args: - tool: The Tool definition - raise_error: If True, raises MCPError on validation failure. - - Returns: - None if valid, ErrorData if invalid and raise_error=False - """ - mode = tool.execution.task_support if tool.execution else None - return self.validate_task_mode(mode, raise_error=raise_error) - - def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: - """Check if this client can use a tool with the given task mode. - - Useful for filtering tool lists or providing warnings. - Returns False if the tool's task mode is "required" but the client doesn't support tasks. - - Args: - tool_task_mode: The tool's execution.taskSupport value - - Returns: - True if the client can use this tool, False otherwise - """ - mode = tool_task_mode or TASK_FORBIDDEN - if mode == TASK_REQUIRED and not self.client_supports_tasks: - return False - return True - - async def run_task( - self, - work: Callable[[ServerTaskContext], Awaitable[Result]], - *, - task_id: str | None = None, - model_immediate_response: str | None = None, - ) -> CreateTaskResult: - """Create a task, spawn background work, and return CreateTaskResult immediately. - - This is the recommended way to handle task-augmented tool calls. It: - 1. Creates a task in the store - 2. Spawns the work function in a background task - 3. Returns CreateTaskResult immediately - - The work function receives a ServerTaskContext with: - - elicit() for sending elicitation requests - - create_message() for sampling requests - - update_status() for progress updates - - complete()/fail() for finishing the task - - When work() returns a Result, the task is auto-completed with that result. - If work() raises an exception, the task is auto-failed. - - Args: - work: Async function that does the actual work - task_id: Optional task ID (generated if not provided) - model_immediate_response: Optional string to include in _meta as - io.modelcontextprotocol/model-immediate-response - - Returns: - CreateTaskResult to return to the client - - Raises: - RuntimeError: If task support is not enabled or task_metadata is missing - - Example: - ```python - async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: - async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit( - message="Are you sure?", - requested_schema={"type": "object", ...} - ) - confirmed = result.content.get("confirm", False) - return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")]) - - return await ctx.experimental.run_task(work) - ``` - - WARNING: This API is experimental and may change without notice. - """ - if self._task_support is None: - raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.") - if self._session is None: - raise RuntimeError("Session not available.") - if self.task_metadata is None: - raise RuntimeError( - "Request is not task-augmented (no task field in params). " - "The client must send a task-augmented request." - ) - - support = self._task_support - # Access task_group via TaskSupport - raises if not in run() context - task_group = support.task_group - - task = await support.store.create_task(self.task_metadata, task_id) - - task_ctx = ServerTaskContext( - task=task, - store=support.store, - session=self._session, - queue=support.queue, - handler=support.handler, - ) - - async def execute() -> None: - try: - result = await work(task_ctx) - if not is_terminal(task_ctx.task.status): - await task_ctx.complete(result) - except Exception as e: - if not is_terminal(task_ctx.task.status): - await task_ctx.fail(str(e)) - - task_group.start_soon(execute) - - meta: dict[str, Any] | None = None - if model_immediate_response is not None: - meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response} - - return CreateTaskResult(task=task, **{"_meta": meta} if meta else {}) diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py deleted file mode 100644 index 2f9d1b0320..0000000000 --- a/src/mcp/server/experimental/session_features.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Experimental server session features for server→client task operations. - -This module provides the server-side equivalent of ExperimentalClientFeatures, -allowing the server to send task-augmented requests to the client and poll for results. - -WARNING: These APIs are experimental and may change without notice. -""" - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, TypeVar - -from mcp import types -from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.experimental.tasks.capabilities import ( - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.shared.experimental.tasks.polling import poll_until_terminal - -if TYPE_CHECKING: - from mcp.server.session import ServerSession - -ResultT = TypeVar("ResultT", bound=types.Result) - - -class ExperimentalServerSessionFeatures: - """Experimental server session features for server→client task operations. - - This provides the server-side equivalent of ExperimentalClientFeatures, - allowing the server to send task-augmented requests to the client and - poll for results. - - WARNING: These APIs are experimental and may change without notice. - - Access via session.experimental: - result = await session.experimental.elicit_as_task(...) - """ - - def __init__(self, session: "ServerSession") -> None: - self._session = session - - async def get_task(self, task_id: str) -> types.GetTaskResult: - """Send tasks/get to the client to get task status. - - Args: - task_id: The task identifier - - Returns: - GetTaskResult containing the task status - """ - return await self._session.send_request( - types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)), - types.GetTaskResult, - ) - - async def get_task_result( - self, - task_id: str, - result_type: type[ResultT], - ) -> ResultT: - """Send tasks/result to the client to retrieve the final result. - - Args: - task_id: The task identifier - result_type: The expected result type - - Returns: - The task result, validated against result_type - """ - return await self._session.send_request( - types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id)), - result_type, - ) - - async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: - """Poll a client task until it reaches terminal status. - - Yields GetTaskResult for each poll, allowing the caller to react to - status changes. Exits when task reaches a terminal status. - - Respects the pollInterval hint from the client. - - Args: - task_id: The task identifier - - Yields: - GetTaskResult for each poll - """ - async for status in poll_until_terminal(self.get_task, task_id): - yield status - - async def elicit_as_task( - self, - message: str, - requested_schema: types.ElicitRequestedSchema, - *, - ttl: int = 60000, - ) -> types.ElicitResult: - """Send a task-augmented elicitation to the client and poll until complete. - - The client will create a local task, process the elicitation asynchronously, - and return the result when ready. This method handles the full flow: - 1. Send elicitation with task field - 2. Receive CreateTaskResult from client - 3. Poll client's task until terminal - 4. Retrieve and return the final ElicitResult - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response - ttl: Task time-to-live in milliseconds - - Returns: - The client's elicitation response - - Raises: - MCPError: If client doesn't support task-augmented elicitation - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_elicitation(client_caps) - - create_result = await self._session.send_request( - types.ElicitRequest( - params=types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=types.TaskMetadata(ttl=ttl), - ) - ), - types.CreateTaskResult, - ) - - task_id = create_result.task.task_id - - async for _ in self.poll_task(task_id): - pass - - return await self.get_task_result(task_id, types.ElicitResult) - - async def create_message_as_task( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - ttl: int = 60000, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - ) -> types.CreateMessageResult: - """Send a task-augmented sampling request and poll until complete. - - The client will create a local task, process the sampling request - asynchronously, and return the result when ready. - - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - ttl: Task time-to-live in milliseconds - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support task-augmented sampling or tools - ValueError: If tool_use or tool_result message structure is invalid - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_sampling(client_caps) - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - create_result = await self._session.send_request( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=types.TaskMetadata(ttl=ttl), - ) - ), - types.CreateTaskResult, - ) - - task_id = create_result.task.task_id - - async for _ in self.poll_task(task_id): - pass - - return await self.get_task_result(task_id, types.CreateMessageResult) diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py deleted file mode 100644 index 1fc45badfd..0000000000 --- a/src/mcp/server/experimental/task_context.py +++ /dev/null @@ -1,587 +0,0 @@ -"""ServerTaskContext - Server-integrated task context with elicitation and sampling. - -This wraps the pure TaskContext and adds server-specific functionality: -- Elicitation (task.elicit()) -- Sampling (task.create_message()) -- Status notifications -""" - -from typing import Any - -import anyio - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.server.session import ServerSession -from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.capabilities import ( - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_REQUEST, - TASK_STATUS_INPUT_REQUIRED, - TASK_STATUS_WORKING, - ClientCapabilities, - CreateMessageResult, - CreateTaskResult, - ElicitationCapability, - ElicitRequestedSchema, - ElicitResult, - IncludeContext, - ModelPreferences, - RequestId, - Result, - SamplingCapability, - SamplingMessage, - Task, - TaskMetadata, - TaskStatusNotification, - TaskStatusNotificationParams, - Tool, - ToolChoice, -) - - -class ServerTaskContext: - """Server-integrated task context with elicitation and sampling. - - This wraps a pure TaskContext and adds server-specific functionality: - - elicit() for sending elicitation requests to the client - - create_message() for sampling requests - - Status notifications via the session - - Example: - ```python - async def my_task_work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Starting...") - - result = await task.elicit( - message="Continue?", - requested_schema={"type": "object", "properties": {"ok": {"type": "boolean"}}} - ) - - if result.content.get("ok"): - return CallToolResult(content=[TextContent(text="Done!")]) - else: - return CallToolResult(content=[TextContent(text="Cancelled")]) - ``` - """ - - def __init__( - self, - *, - task: Task, - store: TaskStore, - session: ServerSession, - queue: TaskMessageQueue, - handler: TaskResultHandler | None = None, - ): - """Create a ServerTaskContext. - - Args: - task: The Task object - store: The task store - session: The server session - queue: The message queue for elicitation/sampling - handler: The result handler for response routing (required for elicit/create_message) - """ - self._ctx = TaskContext(task=task, store=store) - self._session = session - self._queue = queue - self._handler = handler - self._store = store - - # Delegate pure properties to inner context - - @property - def task_id(self) -> str: - """The task identifier.""" - return self._ctx.task_id - - @property - def task(self) -> Task: - """The current task state.""" - return self._ctx.task - - @property - def is_cancelled(self) -> bool: - """Whether cancellation has been requested.""" - return self._ctx.is_cancelled - - def request_cancellation(self) -> None: - """Request cancellation of this task.""" - self._ctx.request_cancellation() - - # Enhanced methods with notifications - - async def update_status(self, message: str, *, notify: bool = True) -> None: - """Update the task's status message. - - Args: - message: The new status message - notify: Whether to send a notification to the client - """ - await self._ctx.update_status(message) - if notify: - await self._send_notification() - - async def complete(self, result: Result, *, notify: bool = True) -> None: - """Mark the task as completed with the given result. - - Args: - result: The task result - notify: Whether to send a notification to the client - """ - await self._ctx.complete(result) - if notify: - await self._send_notification() - - async def fail(self, error: str, *, notify: bool = True) -> None: - """Mark the task as failed with an error message. - - Args: - error: The error message - notify: Whether to send a notification to the client - """ - await self._ctx.fail(error) - if notify: - await self._send_notification() - - async def _send_notification(self) -> None: - """Send a task status notification to the client.""" - task = self._ctx.task - await self._session.send_notification( - TaskStatusNotification( - params=TaskStatusNotificationParams( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - ) - ) - - # Server-specific methods: elicitation and sampling - - def _check_elicitation_capability(self) -> None: - """Check if the client supports elicitation.""" - if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())): - raise MCPError(code=INVALID_REQUEST, message="Client does not support elicitation capability") - - def _check_sampling_capability(self) -> None: - """Check if the client supports sampling.""" - if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())): - raise MCPError(code=INVALID_REQUEST, message="Client does not support sampling capability") - - async def elicit( - self, - message: str, - requested_schema: ElicitRequestedSchema, - ) -> ElicitResult: - """Send an elicitation request via the task message queue. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the elicitation request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. Returns the result - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - - Returns: - The client's response - - Raises: - MCPError: If client doesn't support elicitation capability - """ - self._check_elicitation_capability() - - if self._handler is None: - raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] - message=message, - requested_schema=requested_schema, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return ElicitResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): - # This path is tested in test_elicit_restores_status_on_cancellation - # which verifies status is restored to "working" after cancellation. - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def elicit_url( - self, - message: str, - url: str, - elicitation_id: str, - ) -> ElicitResult: - """Send a URL mode elicitation request via the task message queue. - - This directs the user to an external URL for out-of-band interactions - like OAuth flows, credential collection, or payment processing. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the elicitation request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. Returns the result - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - - Returns: - The client's response indicating acceptance, decline, or cancellation - - Raises: - MCPError: If client doesn't support elicitation capability - RuntimeError: If handler is not configured - """ - self._check_elicitation_capability() - - if self._handler is None: - raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage] - message=message, - url=url, - elicitation_id=elicitation_id, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return ElicitResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def create_message( - self, - messages: list[SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - tools: list[Tool] | None = None, - tool_choice: ToolChoice | None = None, - ) -> CreateMessageResult: - """Send a sampling request via the task message queue. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the sampling request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. Returns the result - - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support sampling capability or tools - ValueError: If tool_use or tool_result message structure is invalid - """ - self._check_sampling_capability() - client_caps = self._session.client_params.capabilities if self._session.client_params else None - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - if self._handler is None: - raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return CreateMessageResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): - # This path is tested in test_create_message_restores_status_on_cancellation - # which verifies status is restored to "working" after cancellation. - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def elicit_as_task( - self, - message: str, - requested_schema: ElicitRequestedSchema, - *, - ttl: int = 60000, - ) -> ElicitResult: - """Send a task-augmented elicitation via the queue, then poll client. - - This is for use inside a task-augmented tool call when you want the client - to handle the elicitation as its own task. The elicitation request is queued - and delivered when the client calls tasks/result. After the client responds - with CreateTaskResult, we poll the client's task until complete. - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - ttl: Task time-to-live in milliseconds for the client's task - - Returns: - The client's elicitation response - - Raises: - MCPError: If client doesn't support task-augmented elicitation - RuntimeError: If handler is not configured - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_elicitation(client_caps) - - if self._handler is None: - raise RuntimeError("handler is required for elicit_as_task()") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] - message=message, - requested_schema=requested_schema, - related_task_id=self.task_id, - task=TaskMetadata(ttl=ttl), - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for initial response (CreateTaskResult from client) - response_data = await resolver.wait() - create_result = CreateTaskResult.model_validate(response_data) - client_task_id = create_result.task.task_id - - # Poll the client's task using session.experimental - async for _ in self._session.experimental.poll_task(client_task_id): - pass - - # Get final result from client - result = await self._session.experimental.get_task_result( - client_task_id, - ElicitResult, - ) - - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return result - - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def create_message_as_task( - self, - messages: list[SamplingMessage], - *, - max_tokens: int, - ttl: int = 60000, - system_prompt: str | None = None, - include_context: IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - tools: list[Tool] | None = None, - tool_choice: ToolChoice | None = None, - ) -> CreateMessageResult: - """Send a task-augmented sampling request via the queue, then poll client. - - This is for use inside a task-augmented tool call when you want the client - to handle the sampling as its own task. The request is queued and delivered - when the client calls tasks/result. After the client responds with - CreateTaskResult, we poll the client's task until complete. - - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - ttl: Task time-to-live in milliseconds for the client's task - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support task-augmented sampling or tools - ValueError: If tool_use or tool_result message structure is invalid - RuntimeError: If handler is not configured - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_sampling(client_caps) - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - if self._handler is None: - raise RuntimeError("handler is required for create_message_as_task()") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build request WITH task field for task-augmented sampling - request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - related_task_id=self.task_id, - task=TaskMetadata(ttl=ttl), - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for initial response (CreateTaskResult from client) - response_data = await resolver.wait() - create_result = CreateTaskResult.model_validate(response_data) - client_task_id = create_result.task.task_id - - # Poll the client's task using session.experimental - async for _ in self._session.experimental.poll_task(client_task_id): - pass - - # Get final result from client - result = await self._session.experimental.get_task_result( - client_task_id, - CreateMessageResult, - ) - - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return result - - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py deleted file mode 100644 index b2268bc1c8..0000000000 --- a/src/mcp/server/experimental/task_result_handler.py +++ /dev/null @@ -1,218 +0,0 @@ -"""TaskResultHandler - Integrated handler for tasks/result endpoint. - -This implements the dequeue-send-wait pattern from the MCP Tasks spec: -1. Dequeue all pending messages for the task -2. Send them to the client via transport with relatedRequestId routing -3. Wait if task is not in terminal state -4. Return final result when task completes - -This is the core of the task message queue pattern. -""" - -import logging -from typing import Any - -import anyio - -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal -from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.types import ( - INVALID_PARAMS, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadResult, - RelatedTaskMetadata, - RequestId, -) - -logger = logging.getLogger(__name__) - - -class TaskResultHandler: - """Handler for tasks/result that implements the message queue pattern. - - This handler: - 1. Dequeues pending messages (elicitations, notifications) for the task - 2. Sends them to the client via the response stream - 3. Waits for responses and resolves them back to callers - 4. Blocks until task reaches terminal state - 5. Returns the final result - - Usage: - async def handle_task_result( - ctx: ServerRequestContext, params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - ... - - server.experimental.enable_tasks( - on_task_result=handle_task_result, - ) - """ - - def __init__( - self, - store: TaskStore, - queue: TaskMessageQueue, - ): - self._store = store - self._queue = queue - # Map from internal request ID to resolver for routing responses - self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {} - - async def send_message( - self, - session: ServerSession, - message: SessionMessage, - ) -> None: - """Send a message via the session. - - This is a helper for delivering queued task messages. - """ - await session.send_message(message) - - async def handle( - self, - request: GetTaskPayloadRequest, - session: ServerSession, - request_id: RequestId, - ) -> GetTaskPayloadResult: - """Handle a tasks/result request. - - This implements the dequeue-send-wait loop: - 1. Dequeue all pending messages - 2. Send each via transport with relatedRequestId = this request's ID - 3. If task not terminal, wait for status change - 4. Loop until task is terminal - 5. Return final result - - Args: - request: The GetTaskPayloadRequest - session: The server session for sending messages - request_id: The request ID for relatedRequestId routing - - Returns: - GetTaskPayloadResult with the task's final payload - """ - task_id = request.params.task_id - - while True: - task = await self._store.get_task(task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}") - - await self._deliver_queued_messages(task_id, session, request_id) - - # If task is terminal, return result - if is_terminal(task.status): - result = await self._store.get_result(task_id) - # GetTaskPayloadResult is a Result with extra="allow" - # The stored result contains the actual payload data - # Per spec: tasks/result MUST include _meta with related-task metadata - related_task = RelatedTaskMetadata(task_id=task_id) - related_task_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)} - if result is not None: - result_data = result.model_dump(by_alias=True) - existing_meta: dict[str, Any] = result_data.get("_meta") or {} - result_data["_meta"] = {**existing_meta, **related_task_meta} - return GetTaskPayloadResult.model_validate(result_data) - return GetTaskPayloadResult.model_validate({"_meta": related_task_meta}) - - # Wait for task update (status change or new messages) - await self._wait_for_task_update(task_id) - - async def _deliver_queued_messages( - self, - task_id: str, - session: ServerSession, - request_id: RequestId, - ) -> None: - """Dequeue and send all pending messages for a task. - - Each message is sent via the session's write stream with - relatedRequestId set so responses route back to this stream. - """ - while True: - message = await self._queue.dequeue(task_id) - if message is None: - break - - # If this is a request (not notification), wait for response - if message.type == "request" and message.resolver is not None: - # Store the resolver so we can route the response back - original_id = message.original_request_id - if original_id is not None: - self._pending_requests[original_id] = message.resolver - - logger.debug("Delivering queued message for task %s: %s", task_id, message.type) - - # Send the message with relatedRequestId for routing - session_message = SessionMessage( - message=message.message, - metadata=ServerMessageMetadata(related_request_id=request_id), - ) - await self.send_message(session, session_message) - - async def _wait_for_task_update(self, task_id: str) -> None: - """Wait for task to be updated (status change or new message). - - Races between store update and queue message - first one wins. - """ - async with anyio.create_task_group() as tg: - - async def wait_for_store() -> None: - try: - await self._store.wait_for_update(task_id) - except Exception: - pass - finally: - tg.cancel_scope.cancel() - - async def wait_for_queue() -> None: - try: - await self._queue.wait_for_message(task_id) - except Exception: - pass - finally: - tg.cancel_scope.cancel() - - tg.start_soon(wait_for_store) - tg.start_soon(wait_for_queue) - - def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: - """Route a response back to the waiting resolver. - - This is called when a response arrives for a queued request. - - Args: - request_id: The request ID from the response - response: The response data - - Returns: - True if response was routed, False if no pending request - """ - resolver = self._pending_requests.pop(request_id, None) - if resolver is not None and not resolver.done(): - resolver.set_result(response) - return True - return False - - def route_error(self, request_id: RequestId, error: ErrorData) -> bool: - """Route an error back to the waiting resolver. - - Args: - request_id: The request ID from the error response - error: The error data - - Returns: - True if error was routed, False if no pending request - """ - resolver = self._pending_requests.pop(request_id, None) - if resolver is not None and not resolver.done(): - resolver.set_exception(MCPError.from_error_data(error)) - return True - return False diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py deleted file mode 100644 index b542195048..0000000000 --- a/src/mcp/server/experimental/task_support.py +++ /dev/null @@ -1,116 +0,0 @@ -"""TaskSupport - Configuration for experimental task support. - -This module provides the TaskSupport class which encapsulates all the -infrastructure needed for task-augmented requests: store, queue, and handler. -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -from anyio.abc import TaskGroup - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.server.session import ServerSession -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue -from mcp.shared.experimental.tasks.store import TaskStore - - -@dataclass -class TaskSupport: - """Configuration for experimental task support. - - Encapsulates the task store, message queue, result handler, and task group - for spawning background work. - - When enabled on a server, this automatically: - - Configures response routing for each session - - Provides default handlers for task operations - - Manages a task group for background task execution - - Example: - Simple in-memory setup: - - ```python - server.experimental.enable_tasks() - ``` - - Custom store/queue for distributed systems: - - ```python - server.experimental.enable_tasks( - store=RedisTaskStore(redis_url), - queue=RedisTaskMessageQueue(redis_url), - ) - ``` - """ - - store: TaskStore - queue: TaskMessageQueue - handler: TaskResultHandler = field(init=False) - _task_group: TaskGroup | None = field(init=False, default=None) - - def __post_init__(self) -> None: - """Create the result handler from store and queue.""" - self.handler = TaskResultHandler(self.store, self.queue) - - @property - def task_group(self) -> TaskGroup: - """Get the task group for spawning background work. - - Raises: - RuntimeError: If not within a run() context - """ - if self._task_group is None: - raise RuntimeError("TaskSupport not running. Ensure Server.run() is active.") - return self._task_group - - @asynccontextmanager - async def run(self) -> AsyncIterator[None]: - """Run the task support lifecycle. - - This creates a task group for spawning background task work. - Called automatically by Server.run(). - - Usage: - async with task_support.run(): - # Task group is now available - ... - """ - async with anyio.create_task_group() as tg: - self._task_group = tg - try: - yield - finally: - self._task_group = None - - def configure_session(self, session: ServerSession) -> None: - """Configure a session for task support. - - This registers the result handler as a response router so that - responses to queued requests (elicitation, sampling) are routed - back to the waiting resolvers. - - Called automatically by Server.run() for each new session. - - Args: - session: The session to configure - """ - session.add_response_router(self.handler) - - @classmethod - def in_memory(cls) -> "TaskSupport": - """Create in-memory task support. - - Suitable for development, testing, and single-process servers. - For distributed systems, provide custom store and queue implementations. - - Returns: - TaskSupport configured with in-memory store and queue - """ - return cls( - store=InMemoryTaskStore(), - queue=InMemoryTaskMessageQueue(), - ) diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py deleted file mode 100644 index 5a907b6407..0000000000 --- a/src/mcp/server/lowlevel/experimental.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Experimental handlers for the low-level MCP server. - -WARNING: These APIs are experimental and may change without notice. -""" - -from __future__ import annotations - -import logging -from collections.abc import Awaitable, Callable -from typing import Any, Generic - -from typing_extensions import TypeVar - -from mcp.server.context import ServerRequestContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import cancel_task -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_PARAMS, - CancelTaskRequestParams, - CancelTaskResult, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - PaginatedRequestParams, - ServerCapabilities, - ServerTasksCapability, - ServerTasksRequestsCapability, - TasksCallCapability, - TasksCancelCapability, - TasksListCapability, - TasksToolsCapability, -) - -logger = logging.getLogger(__name__) - -LifespanResultT = TypeVar("LifespanResultT", default=Any) - - -class ExperimentalHandlers(Generic[LifespanResultT]): - """Experimental request/notification handlers. - - WARNING: These APIs are experimental and may change without notice. - """ - - def __init__( - self, - add_request_handler: Callable[ - [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None - ], - has_handler: Callable[[str], bool], - ) -> None: - self._add_request_handler = add_request_handler - self._has_handler = has_handler - self._task_support: TaskSupport | None = None - - @property - def task_support(self) -> TaskSupport | None: - """Get the task support configuration, if enabled.""" - return self._task_support - - def update_capabilities(self, capabilities: ServerCapabilities) -> None: - # Only add tasks capability if handlers are registered - if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): - return - - capabilities.tasks = ServerTasksCapability() - if self._has_handler("tasks/list"): - capabilities.tasks.list = TasksListCapability() - if self._has_handler("tasks/cancel"): - capabilities.tasks.cancel = TasksCancelCapability() - - capabilities.tasks.requests = ServerTasksRequestsCapability( - tools=TasksToolsCapability(call=TasksCallCapability()) - ) # assuming always supported for now - - def enable_tasks( - self, - store: TaskStore | None = None, - queue: TaskMessageQueue | None = None, - *, - on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]] - | None = None, - on_task_result: Callable[ - [ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult] - ] - | None = None, - on_list_tasks: Callable[ - [ServerRequestContext[LifespanResultT], PaginatedRequestParams | None], Awaitable[ListTasksResult] - ] - | None = None, - on_cancel_task: Callable[ - [ServerRequestContext[LifespanResultT], CancelTaskRequestParams], Awaitable[CancelTaskResult] - ] - | None = None, - ) -> TaskSupport: - """Enable experimental task support. - - This sets up the task infrastructure and registers handlers for - tasks/get, tasks/result, tasks/list, and tasks/cancel. Custom handlers - can be provided via the on_* kwargs; any not provided will use defaults. - - Args: - store: Custom TaskStore implementation (defaults to InMemoryTaskStore) - queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) - on_get_task: Custom handler for tasks/get - on_task_result: Custom handler for tasks/result - on_list_tasks: Custom handler for tasks/list - on_cancel_task: Custom handler for tasks/cancel - - Returns: - The TaskSupport configuration object - - Example: - Simple in-memory setup: - - ```python - server.experimental.enable_tasks() - ``` - - Custom store/queue for distributed systems: - - ```python - server.experimental.enable_tasks( - store=RedisTaskStore(redis_url), - queue=RedisTaskMessageQueue(redis_url), - ) - ``` - - WARNING: This API is experimental and may change without notice. - """ - if store is None: - store = InMemoryTaskStore() - if queue is None: - queue = InMemoryTaskMessageQueue() - - self._task_support = TaskSupport(store=store, queue=queue) - task_support = self._task_support - - # Register user-provided handlers - if on_get_task is not None: - self._add_request_handler("tasks/get", on_get_task) - if on_task_result is not None: - self._add_request_handler("tasks/result", on_task_result) - if on_list_tasks is not None: - self._add_request_handler("tasks/list", on_list_tasks) - if on_cancel_task is not None: - self._add_request_handler("tasks/cancel", on_cancel_task) - - # Fill in defaults for any not provided - if not self._has_handler("tasks/get"): - - async def _default_get_task( - ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams - ) -> GetTaskResult: - task = await task_support.store.get_task(params.task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - self._add_request_handler("tasks/get", _default_get_task) - - if not self._has_handler("tasks/result"): - - async def _default_get_task_result( - ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - assert ctx.request_id is not None - req = GetTaskPayloadRequest(params=params) - result = await task_support.handler.handle(req, ctx.session, ctx.request_id) - return result - - self._add_request_handler("tasks/result", _default_get_task_result) - - if not self._has_handler("tasks/list"): - - async def _default_list_tasks( - ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None - ) -> ListTasksResult: - cursor = params.cursor if params else None - tasks, next_cursor = await task_support.store.list_tasks(cursor) - return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - - self._add_request_handler("tasks/list", _default_list_tasks) - - if not self._has_handler("tasks/cancel"): - - async def _default_cancel_task( - ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams - ) -> CancelTaskResult: - result = await cancel_task(task_support.store, params.task_id) - return result - - self._add_request_handler("tasks/cancel", _default_cancel_task) - - return task_support diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 5e4e2e6f5b..37127c5621 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -59,8 +59,6 @@ async def main(): from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental -from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.streamable_http import EventStore @@ -121,7 +119,7 @@ def __init__( | None = None, on_call_tool: Callable[ [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], - Awaitable[types.CallToolResult | types.CreateTaskResult], + Awaitable[types.CallToolResult], ] | None = None, on_list_resources: Callable[ @@ -197,7 +195,6 @@ def __init__( self._notification_handlers: dict[ str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] ] = {} - self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None logger.debug("Initializing server %r", name) @@ -242,10 +239,6 @@ def _add_request_handler( """Add a request handler, silently replacing any existing handler for the same method.""" self._request_handlers[method] = handler - def _has_handler(self, method: str) -> bool: - """Check if a handler is registered for the given method.""" - return method in self._request_handlers or method in self._notification_handlers - # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities @@ -323,25 +316,8 @@ def get_capabilities( experimental=experimental_capabilities, completions=completions_capability, ) - if self._experimental_handlers: - self._experimental_handlers.update_capabilities(capabilities) return capabilities - @property - def experimental(self) -> ExperimentalHandlers[LifespanResultT]: - """Experimental APIs for tasks and other features. - - WARNING: These APIs are experimental and may change without notice. - """ - - # We create this inline so we only add these capabilities _if_ they're actually used - if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers( - add_request_handler=self._add_request_handler, - has_handler=self._has_handler, - ) - return self._experimental_handlers - @property def session_manager(self) -> StreamableHTTPSessionManager: """Get the StreamableHTTP session manager. @@ -383,12 +359,6 @@ async def run( ) ) - # Configure task support for this session if enabled - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - if task_support is not None: - task_support.configure_session(session) - await stack.enter_async_context(task_support.run()) - async with anyio.create_task_group() as tg: try: async for message in session.incoming_messages: @@ -476,23 +446,11 @@ async def _handle_request( close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - # Get task metadata from request params if present - task_metadata = None - if hasattr(req, "params") and req.params is not None: # pragma: no branch - task_metadata = getattr(req.params, "task", None) ctx = ServerRequestContext( request_id=message.request_id, meta=message.request_meta, session=session, lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), request=request_data, close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, @@ -543,17 +501,9 @@ async def _handle_notification( logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None ctx = ServerRequestContext( session=session, lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=None, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), ) await handler(ctx, notify.params) except Exception: # pragma: no cover diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index fc2f97a9cb..3fc7bbf0d3 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -37,13 +37,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from pydantic import AnyUrl, TypeAdapter from mcp import types -from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import StatelessModeNotSupported -from mcp.shared.experimental.tasks.capabilities import check_tasks_capability -from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -76,7 +73,6 @@ class ServerSession( ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None - _experimental_features: ExperimentalServerSessionFeatures | None = None def __init__( self, @@ -109,16 +105,6 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification] def client_params(self) -> types.InitializeRequestParams | None: return self._client_params - @property - def experimental(self) -> ExperimentalServerSessionFeatures: - """Experimental APIs for server→client task operations. - - WARNING: These APIs are experimental and may change without notice. - """ - if self._experimental_features is None: - self._experimental_features = ExperimentalServerSessionFeatures(self) - return self._experimental_features - def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" if self._client_params is None: # pragma: lax no cover @@ -150,12 +136,6 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False - if capability.tasks is not None: # pragma: lax no cover - if client_caps.tasks is None: - return False - if not check_tasks_capability(capability.tasks, client_caps.tasks): - return False - return True async def _receive_loop(self) -> None: @@ -509,181 +489,6 @@ async def send_elicit_complete( related_request_id, ) - def _build_elicit_form_request( - self, - message: str, - requested_schema: types.ElicitRequestedSchema, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a form mode elicitation request without sending it. - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, - ) - - def _build_elicit_url_request( - self, - message: str, - url: str, - elicitation_id: str, - related_task_id: str | None = None, - ) -> types.JSONRPCRequest: - """Build a URL mode elicitation request without sending it. - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestURLParams( - message=message, - url=url, - elicitation_id=elicitation_id, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, - ) - - def _build_create_message_request( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a sampling/createMessage request without sending it. - - Args: - messages: The conversation messages to send - max_tokens: Maximum number of tokens to generate - system_prompt: Optional system prompt - include_context: Optional context inclusion setting - temperature: Optional sampling temperature - stop_sequences: Optional stop sequences - metadata: Optional metadata to pass through to the LLM provider - model_preferences: Optional model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.CreateMessageRequestParams( - messages=messages, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - max_tokens=max_tokens, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="sampling/createMessage", - params=params_data, - ) - - async def send_message(self, message: SessionMessage) -> None: - """Send a raw session message. - - This is primarily used by TaskResultHandler to deliver queued messages - (elicitation/sampling requests) to the client during task execution. - - WARNING: This is a low-level experimental method that may change without - notice. Prefer using higher-level methods like send_notification() or - send_request() for normal operations. - - Args: - message: The session message to send - """ - await self._write_stream.send(message) - async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/server/validation.py b/src/mcp/server/validation.py index 5708628074..08f5754f1e 100644 --- a/src/mcp/server/validation.py +++ b/src/mcp/server/validation.py @@ -1,7 +1,6 @@ """Shared validation functions for server requests. -This module provides validation logic for sampling and elicitation requests -that is shared across normal and task-augmented code paths. +This module provides validation logic for sampling and elicitation requests. """ from mcp.shared.exceptions import MCPError diff --git a/src/mcp/shared/experimental/__init__.py b/src/mcp/shared/experimental/__init__.py deleted file mode 100644 index fa6940acc6..0000000000 --- a/src/mcp/shared/experimental/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Pure experimental MCP features (no server dependencies). - -WARNING: These APIs are experimental and may change without notice. - -For server-integrated experimental features, use mcp.server.experimental. -""" diff --git a/src/mcp/shared/experimental/tasks/__init__.py b/src/mcp/shared/experimental/tasks/__init__.py deleted file mode 100644 index 52793e408b..0000000000 --- a/src/mcp/shared/experimental/tasks/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Pure task state management for MCP. - -WARNING: These APIs are experimental and may change without notice. - -Import directly from submodules: -- mcp.shared.experimental.tasks.store.TaskStore -- mcp.shared.experimental.tasks.context.TaskContext -- mcp.shared.experimental.tasks.in_memory_task_store.InMemoryTaskStore -- mcp.shared.experimental.tasks.message_queue.TaskMessageQueue -- mcp.shared.experimental.tasks.helpers.is_terminal -""" diff --git a/src/mcp/shared/experimental/tasks/capabilities.py b/src/mcp/shared/experimental/tasks/capabilities.py deleted file mode 100644 index 51fe64ecc3..0000000000 --- a/src/mcp/shared/experimental/tasks/capabilities.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Tasks capability checking utilities. - -This module provides functions for checking and requiring task-related -capabilities. All tasks capability logic is centralized here to keep -the main session code clean. - -WARNING: These APIs are experimental and may change without notice. -""" - -from mcp.shared.exceptions import MCPError -from mcp.types import INVALID_REQUEST, ClientCapabilities, ClientTasksCapability - - -def check_tasks_capability( - required: ClientTasksCapability, - client: ClientTasksCapability, -) -> bool: - """Check if client's tasks capability matches the required capability. - - Args: - required: The capability being checked for - client: The client's declared capabilities - - Returns: - True if client has the required capability, False otherwise - """ - if required.requests is None: - return True - if client.requests is None: - return False - - # Check elicitation.create - if required.requests.elicitation is not None: - if client.requests.elicitation is None: - return False - if required.requests.elicitation.create is not None: - if client.requests.elicitation.create is None: - return False - - # Check sampling.createMessage - if required.requests.sampling is not None: - if client.requests.sampling is None: - return False - if required.requests.sampling.create_message is not None: - if client.requests.sampling.create_message is None: - return False - - return True - - -def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool: - """Check if capabilities include task-augmented elicitation support.""" - if caps.tasks is None: - return False - if caps.tasks.requests is None: - return False - if caps.tasks.requests.elicitation is None: - return False - return caps.tasks.requests.elicitation.create is not None - - -def has_task_augmented_sampling(caps: ClientCapabilities) -> bool: - """Check if capabilities include task-augmented sampling support.""" - if caps.tasks is None: - return False - if caps.tasks.requests is None: - return False - if caps.tasks.requests.sampling is None: - return False - return caps.tasks.requests.sampling.create_message is not None - - -def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None: - """Raise MCPError if client doesn't support task-augmented elicitation. - - Args: - client_caps: The client's declared capabilities, or None if not initialized - - Raises: - MCPError: If client doesn't support task-augmented elicitation - """ - if client_caps is None or not has_task_augmented_elicitation(client_caps): - raise MCPError(code=INVALID_REQUEST, message="Client does not support task-augmented elicitation") - - -def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None: - """Raise MCPError if client doesn't support task-augmented sampling. - - Args: - client_caps: The client's declared capabilities, or None if not initialized - - Raises: - MCPError: If client doesn't support task-augmented sampling - """ - if client_caps is None or not has_task_augmented_sampling(client_caps): - raise MCPError(code=INVALID_REQUEST, message="Client does not support task-augmented sampling") diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py deleted file mode 100644 index ed0d2b91b6..0000000000 --- a/src/mcp/shared/experimental/tasks/context.py +++ /dev/null @@ -1,95 +0,0 @@ -"""TaskContext - Pure task state management. - -This module provides TaskContext, which manages task state without any -server/session dependencies. It can be used standalone for distributed -workers or wrapped by ServerTaskContext for full server integration. -""" - -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, Result, Task - - -class TaskContext: - """Pure task state management - no session dependencies. - - This class handles: - - Task state (status, result) - - Cancellation tracking - - Store interactions - - For server-integrated features (elicit, create_message, notifications), - use ServerTaskContext from mcp.server.experimental. - - Example (distributed worker): - async def worker_job(task_id: str): - store = RedisTaskStore(redis_url) - task = await store.get_task(task_id) - ctx = TaskContext(task=task, store=store) - - await ctx.update_status("Working...") - result = await do_work() - await ctx.complete(result) - """ - - def __init__(self, task: Task, store: TaskStore): - self._task = task - self._store = store - self._cancelled = False - - @property - def task_id(self) -> str: - """The task identifier.""" - return self._task.task_id - - @property - def task(self) -> Task: - """The current task state.""" - return self._task - - @property - def is_cancelled(self) -> bool: - """Whether cancellation has been requested.""" - return self._cancelled - - def request_cancellation(self) -> None: - """Request cancellation of this task. - - This sets is_cancelled=True. Task work should check this - periodically and exit gracefully if set. - """ - self._cancelled = True - - async def update_status(self, message: str) -> None: - """Update the task's status message. - - Args: - message: The new status message - """ - self._task = await self._store.update_task( - self.task_id, - status_message=message, - ) - - async def complete(self, result: Result) -> None: - """Mark the task as completed with the given result. - - Args: - result: The task result - """ - await self._store.store_result(self.task_id, result) - self._task = await self._store.update_task( - self.task_id, - status=TASK_STATUS_COMPLETED, - ) - - async def fail(self, error: str) -> None: - """Mark the task as failed with an error message. - - Args: - error: The error message - """ - self._task = await self._store.update_task( - self.task_id, - status=TASK_STATUS_FAILED, - status_message=error, - ) diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py deleted file mode 100644 index 3f91cd0d06..0000000000 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Helper functions for pure task management. - -These helpers work with pure TaskContext and don't require server dependencies. -For server-integrated task helpers, use mcp.server.experimental. -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from datetime import datetime, timezone -from uuid import uuid4 - -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_PARAMS, - TASK_STATUS_CANCELLED, - TASK_STATUS_COMPLETED, - TASK_STATUS_FAILED, - TASK_STATUS_WORKING, - CancelTaskResult, - Task, - TaskMetadata, - TaskStatus, -) - -# Metadata key for model-immediate-response (per MCP spec) -# Servers MAY include this in CreateTaskResult._meta to provide an immediate -# response string while the task executes in the background. -MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response" - -# Metadata key for associating requests with a task (per MCP spec) -RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task" - - -def is_terminal(status: TaskStatus) -> bool: - """Check if a task status represents a terminal state. - - Terminal states are those where the task has finished and will not change. - - Args: - status: The task status to check - - Returns: - True if the status is terminal (completed, failed, or cancelled) - """ - return status in (TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED) - - -async def cancel_task( - store: TaskStore, - task_id: str, -) -> CancelTaskResult: - """Cancel a task with spec-compliant validation. - - Per spec: "Receivers MUST reject cancellation of terminal status tasks - with -32602 (Invalid params)" - - This helper validates that the task exists and is not in a terminal state - before setting it to "cancelled". - - Args: - store: The task store - task_id: The task identifier to cancel - - Returns: - CancelTaskResult with the cancelled task state - - Raises: - MCPError: With INVALID_PARAMS (-32602) if: - - Task does not exist - - Task is already in a terminal state (completed, failed, cancelled) - - Example: - ```python - async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: - return await cancel_task(store, params.task_id) - ``` - """ - task = await store.get_task(task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}") - - if is_terminal(task.status): - raise MCPError(code=INVALID_PARAMS, message=f"Cannot cancel task in terminal state '{task.status}'") - - # Update task to cancelled status - cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED) - return CancelTaskResult(**cancelled_task.model_dump()) - - -def generate_task_id() -> str: - """Generate a unique task ID.""" - return str(uuid4()) - - -def create_task_state( - metadata: TaskMetadata, - task_id: str | None = None, -) -> Task: - """Create a Task object with initial state. - - This is a helper for TaskStore implementations. - - Args: - metadata: Task metadata - task_id: Optional task ID (generated if not provided) - - Returns: - A new Task in "working" status - """ - now = datetime.now(timezone.utc) - return Task( - task_id=task_id or generate_task_id(), - status=TASK_STATUS_WORKING, - created_at=now, - last_updated_at=now, - ttl=metadata.ttl, - poll_interval=500, # Default 500ms poll interval - ) - - -@asynccontextmanager -async def task_execution( - task_id: str, - store: TaskStore, -) -> AsyncIterator[TaskContext]: - """Context manager for safe task execution (pure, no server dependencies). - - Loads a task from the store and provides a TaskContext for the work. - If an unhandled exception occurs, the task is automatically marked as failed - and the exception is suppressed (since the failure is captured in task state). - - This is useful for distributed workers that don't have a server session. - - Args: - task_id: The task identifier to execute - store: The task store (must be accessible by the worker) - - Yields: - TaskContext for updating status and completing/failing the task - - Raises: - ValueError: If the task is not found in the store - - Example (distributed worker): - async def worker_process(task_id: str): - store = RedisTaskStore(redis_url) - async with task_execution(task_id, store) as ctx: - await ctx.update_status("Working...") - result = await do_work() - await ctx.complete(result) - """ - task = await store.get_task(task_id) - if task is None: - raise ValueError(f"Task {task_id} not found") - - ctx = TaskContext(task, store) - try: - yield ctx - except Exception as e: - # Auto-fail the task if an exception occurs and task isn't already terminal - # Exception is suppressed since failure is captured in task state - if not is_terminal(ctx.task.status): - await ctx.fail(str(e)) - # Don't re-raise - the failure is recorded in task state diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py deleted file mode 100644 index 42f4fb7035..0000000000 --- a/src/mcp/shared/experimental/tasks/in_memory_task_store.py +++ /dev/null @@ -1,217 +0,0 @@ -"""In-memory implementation of TaskStore for demonstration purposes. - -This implementation stores all tasks in memory and provides automatic cleanup -based on the TTL duration specified in the task metadata using lazy expiration. - -Note: This is not suitable for production use as all data is lost on restart. -For production, consider implementing TaskStore with a database or distributed cache. -""" - -from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone - -import anyio - -from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import Result, Task, TaskMetadata, TaskStatus - - -@dataclass -class StoredTask: - """Internal storage representation of a task.""" - - task: Task - result: Result | None = None - # Time when this task should be removed (None = never) - expires_at: datetime | None = field(default=None) - - -class InMemoryTaskStore(TaskStore): - """A simple in-memory implementation of TaskStore. - - Features: - - Automatic TTL-based cleanup (lazy expiration) - - Thread-safe for single-process async use - - Pagination support for list_tasks - - Limitations: - - All data lost on restart - - Not suitable for distributed systems - - No persistence - - For production, implement TaskStore with Redis, PostgreSQL, etc. - """ - - def __init__(self, page_size: int = 10) -> None: - self._tasks: dict[str, StoredTask] = {} - self._page_size = page_size - self._update_events: dict[str, anyio.Event] = {} - - def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None: - """Calculate expiry time from TTL in milliseconds.""" - if ttl_ms is None: - return None - return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms) - - def _is_expired(self, stored: StoredTask) -> bool: - """Check if a task has expired.""" - if stored.expires_at is None: - return False - return datetime.now(timezone.utc) >= stored.expires_at - - def _cleanup_expired(self) -> None: - """Remove all expired tasks. Called lazily during access operations.""" - expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)] - for task_id in expired_ids: - del self._tasks[task_id] - - async def create_task( - self, - metadata: TaskMetadata, - task_id: str | None = None, - ) -> Task: - """Create a new task with the given metadata.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - task = create_task_state(metadata, task_id) - - if task.task_id in self._tasks: - raise ValueError(f"Task with ID {task.task_id} already exists") - - stored = StoredTask( - task=task, - expires_at=self._calculate_expiry(metadata.ttl), - ) - self._tasks[task.task_id] = stored - - # Return a copy to prevent external modification - return Task(**task.model_dump()) - - async def get_task(self, task_id: str) -> Task | None: - """Get a task by ID.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - stored = self._tasks.get(task_id) - if stored is None: - return None - - # Return a copy to prevent external modification - return Task(**stored.task.model_dump()) - - async def update_task( - self, - task_id: str, - status: TaskStatus | None = None, - status_message: str | None = None, - ) -> Task: - """Update a task's status and/or message.""" - stored = self._tasks.get(task_id) - if stored is None: - raise ValueError(f"Task with ID {task_id} not found") - - # Per spec: Terminal states MUST NOT transition to any other status - if status is not None and status != stored.task.status and is_terminal(stored.task.status): - raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'") - - status_changed = False - if status is not None and stored.task.status != status: - stored.task.status = status - status_changed = True - - if status_message is not None: - stored.task.status_message = status_message - - # Update last_updated_at on any change - stored.task.last_updated_at = datetime.now(timezone.utc) - - # If task is now terminal and has TTL, reset expiry timer - if status is not None and is_terminal(status) and stored.task.ttl is not None: - stored.expires_at = self._calculate_expiry(stored.task.ttl) - - # Notify waiters if status changed - if status_changed: - await self.notify_update(task_id) - - return Task(**stored.task.model_dump()) - - async def store_result(self, task_id: str, result: Result) -> None: - """Store the result for a task.""" - stored = self._tasks.get(task_id) - if stored is None: - raise ValueError(f"Task with ID {task_id} not found") - - stored.result = result - - async def get_result(self, task_id: str) -> Result | None: - """Get the stored result for a task.""" - stored = self._tasks.get(task_id) - if stored is None: - return None - - return stored.result - - async def list_tasks( - self, - cursor: str | None = None, - ) -> tuple[list[Task], str | None]: - """List tasks with pagination.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - all_task_ids = list(self._tasks.keys()) - - start_index = 0 - if cursor is not None: - try: - cursor_index = all_task_ids.index(cursor) - start_index = cursor_index + 1 - except ValueError: - raise ValueError(f"Invalid cursor: {cursor}") - - page_task_ids = all_task_ids[start_index : start_index + self._page_size] - tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids] - - # Determine next cursor - next_cursor = None - if start_index + self._page_size < len(all_task_ids) and page_task_ids: - next_cursor = page_task_ids[-1] - - return tasks, next_cursor - - async def delete_task(self, task_id: str) -> bool: - """Delete a task.""" - if task_id not in self._tasks: - return False - - del self._tasks[task_id] - return True - - async def wait_for_update(self, task_id: str) -> None: - """Wait until the task status changes.""" - if task_id not in self._tasks: - raise ValueError(f"Task with ID {task_id} not found") - - # Create a fresh event for waiting (anyio.Event can't be cleared) - self._update_events[task_id] = anyio.Event() - event = self._update_events[task_id] - await event.wait() - - async def notify_update(self, task_id: str) -> None: - """Signal that a task has been updated.""" - if task_id in self._update_events: - self._update_events[task_id].set() - - # --- Testing/debugging helpers --- - - def cleanup(self) -> None: - """Cleanup all tasks (useful for testing or graceful shutdown).""" - self._tasks.clear() - self._update_events.clear() - - def get_all_tasks(self) -> list[Task]: - """Get all tasks (useful for debugging). Returns copies to prevent modification.""" - self._cleanup_expired() - return [Task(**stored.task.model_dump()) for stored in self._tasks.values()] diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py deleted file mode 100644 index e17c4a8650..0000000000 --- a/src/mcp/shared/experimental/tasks/message_queue.py +++ /dev/null @@ -1,230 +0,0 @@ -"""TaskMessageQueue - FIFO queue for task-related messages. - -This implements the core message queue pattern from the MCP Tasks spec. -When a handler needs to send a request (like elicitation) during a task-augmented -request, the message is enqueued instead of sent directly. Messages are delivered -to the client only through the `tasks/result` endpoint. - -This pattern enables: -1. Decoupling request handling from message delivery -2. Proper bidirectional communication via the tasks/result stream -3. Automatic status management (working <-> input_required) -""" - -from abc import ABC, abstractmethod -from collections import deque -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, Literal - -import anyio - -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId - - -@dataclass -class QueuedMessage: - """A message queued for delivery via tasks/result. - - Messages are stored with their type and a resolver for requests - that expect responses. - """ - - type: Literal["request", "notification"] - """Whether this is a request (expects response) or notification (one-way).""" - - message: JSONRPCRequest | JSONRPCNotification - """The JSON-RPC message to send.""" - - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - """When the message was enqueued.""" - - resolver: Resolver[dict[str, Any]] | None = None - """Resolver to set when response arrives (only for requests).""" - - original_request_id: RequestId | None = None - """The original request ID used internally, for routing responses back.""" - - -class TaskMessageQueue(ABC): - """Abstract interface for task message queuing. - - This is a FIFO queue that stores messages to be delivered via `tasks/result`. - When a task-augmented handler calls elicit() or sends a notification, the - message is enqueued here instead of being sent directly to the client. - - The `tasks/result` handler then dequeues and sends these messages through - the transport, with `relatedRequestId` set to the tasks/result request ID - so responses are routed correctly. - - Implementations can use in-memory storage, Redis, etc. - """ - - @abstractmethod - async def enqueue(self, task_id: str, message: QueuedMessage) -> None: - """Add a message to the queue for a task. - - Args: - task_id: The task identifier - message: The message to enqueue - """ - - @abstractmethod - async def dequeue(self, task_id: str) -> QueuedMessage | None: - """Remove and return the next message from the queue. - - Args: - task_id: The task identifier - - Returns: - The next message, or None if queue is empty - """ - - @abstractmethod - async def peek(self, task_id: str) -> QueuedMessage | None: - """Return the next message without removing it. - - Args: - task_id: The task identifier - - Returns: - The next message, or None if queue is empty - """ - - @abstractmethod - async def is_empty(self, task_id: str) -> bool: - """Check if the queue is empty for a task. - - Args: - task_id: The task identifier - - Returns: - True if no messages are queued - """ - - @abstractmethod - async def clear(self, task_id: str) -> list[QueuedMessage]: - """Remove and return all messages from the queue. - - This is useful for cleanup when a task is cancelled or completed. - - Args: - task_id: The task identifier - - Returns: - All queued messages (may be empty) - """ - - @abstractmethod - async def wait_for_message(self, task_id: str) -> None: - """Wait until a message is available in the queue. - - This blocks until either: - 1. A message is enqueued for this task - 2. The wait is cancelled - - Args: - task_id: The task identifier - """ - - @abstractmethod - async def notify_message_available(self, task_id: str) -> None: - """Signal that a message is available for a task. - - This wakes up any coroutines waiting in wait_for_message(). - - Args: - task_id: The task identifier - """ - - -class InMemoryTaskMessageQueue(TaskMessageQueue): - """In-memory implementation of TaskMessageQueue. - - This is suitable for single-process servers. For distributed systems, - implement TaskMessageQueue with Redis, RabbitMQ, etc. - - Features: - - FIFO ordering per task - - Async wait for message availability - - Thread-safe for single-process async use - """ - - def __init__(self) -> None: - self._queues: dict[str, deque[QueuedMessage]] = {} - self._events: dict[str, anyio.Event] = {} - - def _get_queue(self, task_id: str) -> deque[QueuedMessage]: - """Get or create the queue for a task.""" - if task_id not in self._queues: - self._queues[task_id] = deque() - return self._queues[task_id] - - async def enqueue(self, task_id: str, message: QueuedMessage) -> None: - """Add a message to the queue.""" - queue = self._get_queue(task_id) - queue.append(message) - # Signal that a message is available - await self.notify_message_available(task_id) - - async def dequeue(self, task_id: str) -> QueuedMessage | None: - """Remove and return the next message.""" - queue = self._get_queue(task_id) - if not queue: - return None - return queue.popleft() - - async def peek(self, task_id: str) -> QueuedMessage | None: - """Return the next message without removing it.""" - queue = self._get_queue(task_id) - if not queue: - return None - return queue[0] - - async def is_empty(self, task_id: str) -> bool: - """Check if the queue is empty.""" - queue = self._get_queue(task_id) - return len(queue) == 0 - - async def clear(self, task_id: str) -> list[QueuedMessage]: - """Remove and return all messages.""" - queue = self._get_queue(task_id) - messages = list(queue) - queue.clear() - return messages - - async def wait_for_message(self, task_id: str) -> None: - """Wait until a message is available.""" - # Check if there are already messages - if not await self.is_empty(task_id): - return - - # Create a fresh event for waiting (anyio.Event can't be cleared) - self._events[task_id] = anyio.Event() - event = self._events[task_id] - - # Double-check after creating event (avoid race condition) - if not await self.is_empty(task_id): - return - - # Wait for a new message - await event.wait() - - async def notify_message_available(self, task_id: str) -> None: - """Signal that a message is available.""" - if task_id in self._events: - self._events[task_id].set() - - def cleanup(self, task_id: str | None = None) -> None: - """Clean up queues and events. - - Args: - task_id: If provided, clean up only this task. Otherwise clean up all. - """ - if task_id is not None: - self._queues.pop(task_id, None) - self._events.pop(task_id, None) - else: - self._queues.clear() - self._events.clear() diff --git a/src/mcp/shared/experimental/tasks/polling.py b/src/mcp/shared/experimental/tasks/polling.py deleted file mode 100644 index e4e13b6640..0000000000 --- a/src/mcp/shared/experimental/tasks/polling.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Shared polling utilities for task operations. - -This module provides generic polling logic that works for both client→server -and server→client task polling. - -WARNING: These APIs are experimental and may change without notice. -""" - -from collections.abc import AsyncIterator, Awaitable, Callable - -import anyio - -from mcp.shared.experimental.tasks.helpers import is_terminal -from mcp.types import GetTaskResult - - -async def poll_until_terminal( - get_task: Callable[[str], Awaitable[GetTaskResult]], - task_id: str, - default_interval_ms: int = 500, -) -> AsyncIterator[GetTaskResult]: - """Poll a task until it reaches terminal status. - - This is a generic utility that works for both client→server and server→client - polling. The caller provides the get_task function appropriate for their direction. - - Args: - get_task: Async function that takes task_id and returns GetTaskResult - task_id: The task to poll - default_interval_ms: Fallback poll interval if server doesn't specify - - Yields: - GetTaskResult for each poll - """ - while True: - status = await get_task(task_id) - yield status - - if is_terminal(status.status): - break - - interval_ms = status.poll_interval if status.poll_interval is not None else default_interval_ms - await anyio.sleep(interval_ms / 1000) diff --git a/src/mcp/shared/experimental/tasks/resolver.py b/src/mcp/shared/experimental/tasks/resolver.py deleted file mode 100644 index 1d233a9309..0000000000 --- a/src/mcp/shared/experimental/tasks/resolver.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Resolver - An anyio-compatible future-like object for async result passing. - -This provides a simple way to pass a result (or exception) from one coroutine -to another without depending on asyncio.Future. -""" - -from typing import Generic, TypeVar, cast - -import anyio - -T = TypeVar("T") - - -class Resolver(Generic[T]): - """A simple resolver for passing results between coroutines. - - Unlike asyncio.Future, this works with any anyio-compatible async backend. - - Usage: - resolver: Resolver[str] = Resolver() - - # In one coroutine: - resolver.set_result("hello") - - # In another coroutine: - result = await resolver.wait() # returns "hello" - """ - - def __init__(self) -> None: - self._event = anyio.Event() - self._value: T | None = None - self._exception: BaseException | None = None - - def set_result(self, value: T) -> None: - """Set the result value and wake up waiters.""" - if self._event.is_set(): - raise RuntimeError("Resolver already completed") - self._value = value - self._event.set() - - def set_exception(self, exc: BaseException) -> None: - """Set an exception and wake up waiters.""" - if self._event.is_set(): - raise RuntimeError("Resolver already completed") - self._exception = exc - self._event.set() - - async def wait(self) -> T: - """Wait for the result and return it, or raise the exception.""" - await self._event.wait() - if self._exception is not None: - raise self._exception - # If we reach here, set_result() was called, so _value is set - return cast(T, self._value) - - def done(self) -> bool: - """Return True if the resolver has been completed.""" - return self._event.is_set() diff --git a/src/mcp/shared/experimental/tasks/store.py b/src/mcp/shared/experimental/tasks/store.py deleted file mode 100644 index 7de97d40ca..0000000000 --- a/src/mcp/shared/experimental/tasks/store.py +++ /dev/null @@ -1,144 +0,0 @@ -"""TaskStore - Abstract interface for task state storage.""" - -from abc import ABC, abstractmethod - -from mcp.types import Result, Task, TaskMetadata, TaskStatus - - -class TaskStore(ABC): - """Abstract interface for task state storage. - - This is a pure storage interface - it doesn't manage execution. - Implementations can use in-memory storage, databases, Redis, etc. - - All methods are async to support various backends. - """ - - @abstractmethod - async def create_task( - self, - metadata: TaskMetadata, - task_id: str | None = None, - ) -> Task: - """Create a new task. - - Args: - metadata: Task metadata (ttl, etc.) - task_id: Optional task ID. If None, implementation should generate one. - - Returns: - The created Task with status="working" - - Raises: - ValueError: If task_id already exists - """ - - @abstractmethod - async def get_task(self, task_id: str) -> Task | None: - """Get a task by ID. - - Args: - task_id: The task identifier - - Returns: - The Task, or None if not found - """ - - @abstractmethod - async def update_task( - self, - task_id: str, - status: TaskStatus | None = None, - status_message: str | None = None, - ) -> Task: - """Update a task's status and/or message. - - Args: - task_id: The task identifier - status: New status (if changing) - status_message: New status message (if changing) - - Returns: - The updated Task - - Raises: - ValueError: If task not found - ValueError: If attempting to transition from a terminal status - (completed, failed, cancelled). Per spec, terminal states - MUST NOT transition to any other status. - """ - - @abstractmethod - async def store_result(self, task_id: str, result: Result) -> None: - """Store the result for a task. - - Args: - task_id: The task identifier - result: The result to store - - Raises: - ValueError: If task not found - """ - - @abstractmethod - async def get_result(self, task_id: str) -> Result | None: - """Get the stored result for a task. - - Args: - task_id: The task identifier - - Returns: - The stored Result, or None if not available - """ - - @abstractmethod - async def list_tasks( - self, - cursor: str | None = None, - ) -> tuple[list[Task], str | None]: - """List tasks with pagination. - - Args: - cursor: Optional cursor for pagination - - Returns: - Tuple of (tasks, next_cursor). next_cursor is None if no more pages. - """ - - @abstractmethod - async def delete_task(self, task_id: str) -> bool: - """Delete a task. - - Args: - task_id: The task identifier - - Returns: - True if deleted, False if not found - """ - - @abstractmethod - async def wait_for_update(self, task_id: str) -> None: - """Wait until the task status changes. - - This blocks until either: - 1. The task status changes - 2. The wait is cancelled - - Used by tasks/result to wait for task completion or status changes. - - Args: - task_id: The task identifier - - Raises: - ValueError: If task not found - """ - - @abstractmethod - async def notify_update(self, task_id: str) -> None: - """Signal that a task has been updated. - - This wakes up any coroutines waiting in wait_for_update(). - - Args: - task_id: The task identifier - """ diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py deleted file mode 100644 index fe24b016f1..0000000000 --- a/src/mcp/shared/response_router.py +++ /dev/null @@ -1,61 +0,0 @@ -"""ResponseRouter - Protocol for pluggable response routing. - -This module defines a protocol for routing JSON-RPC responses to alternative -handlers before falling back to the default response stream mechanism. - -The primary use case is task-augmented requests: when a TaskSession enqueues -a request (like elicitation), the response needs to be routed back to the -waiting resolver instead of the normal response stream. - -Design: -- Protocol-based for testability and flexibility -- Returns bool to indicate if response was handled -- Supports both success responses and errors -""" - -from typing import Any, Protocol - -from mcp.types import ErrorData, RequestId - - -class ResponseRouter(Protocol): - """Protocol for routing responses to alternative handlers. - - Implementations check if they have a pending request for the given ID - and deliver the response/error to the appropriate handler. - - Example: - ```python - class TaskResultHandler(ResponseRouter): - def route_response(self, request_id, response): - resolver = self._pending_requests.pop(request_id, None) - if resolver: - resolver.set_result(response) - return True - return False - ``` - """ - - def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: - """Try to route a response to a pending request handler. - - Args: - request_id: The JSON-RPC request ID from the response - response: The response result data - - Returns: - True if the response was handled, False otherwise - """ - ... # pragma: no cover - - def route_error(self, request_id: RequestId, error: ErrorData) -> bool: - """Try to route an error to a pending request handler. - - Args: - request_id: The JSON-RPC request ID from the error response - error: The error data - - Returns: - True if the error was handled, False otherwise - """ - ... # pragma: no cover diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9c72a23844..ea5d8833bd 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -17,7 +17,6 @@ from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage -from mcp.shared.response_router import ResponseRouter from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, @@ -183,7 +182,6 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] - _response_routers: list[ResponseRouter] def __init__( self, @@ -199,24 +197,8 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} - self._response_routers = [] self._exit_stack = AsyncExitStack() - def add_response_router(self, router: ResponseRouter) -> None: - """Register a response router to handle responses for non-standard requests. - - Response routers are checked in order before falling back to the default - response stream mechanism. This is used by TaskResultHandler to route - responses for queued task requests back to their resolvers. - - !!! warning - This is an experimental API that may change without notice. - - Args: - router: A ResponseRouter implementation - """ - self._response_routers.append(router) - async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() @@ -477,11 +459,7 @@ def _normalize_request_id(self, response_id: RequestId) -> RequestId: return response_id async def _handle_response(self, message: SessionMessage) -> None: - """Handle an incoming response or error message. - - Checks response routers first (e.g., for task-related responses), - then falls back to the normal response stream mechanism. - """ + """Handle an incoming response or error message.""" # This check is always true at runtime: the caller (_receive_loop) only invokes # this method in the else branch after checking for JSONRPCRequest and # JSONRPCNotification. However, the type checker can't infer this from the @@ -498,20 +476,6 @@ async def _handle_response(self, message: SessionMessage) -> None: # Normalize response ID to handle type mismatches (e.g., "0" vs 0) response_id = self._normalize_request_id(message.message.id) - # First, check response routers (e.g., TaskResultHandler) - if isinstance(message.message, JSONRPCError): - # Route error to routers - for router in self._response_routers: - if router.route_error(response_id, message.message.error): - return # Handled - else: - # Route success response to routers - response_data: dict[str, Any] = message.message.result or {} - for router in self._response_routers: - if router.route_response(response_id, response_data): - return # Handled - - # Fall back to normal response streams stream = self._response_streams.pop(response_id, None) if stream: await stream.send(message.message) diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b442303937..b2d537fb70 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -8,14 +8,6 @@ from mcp.types._types import ( DEFAULT_NEGOTIATED_VERSION, LATEST_PROTOCOL_VERSION, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - TASK_STATUS_CANCELLED, - TASK_STATUS_COMPLETED, - TASK_STATUS_FAILED, - TASK_STATUS_INPUT_REQUIRED, - TASK_STATUS_WORKING, Annotations, AudioContent, BaseMetadata, @@ -25,15 +17,10 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, - CancelTaskRequest, - CancelTaskRequestParams, - CancelTaskResult, ClientCapabilities, ClientNotification, ClientRequest, ClientResult, - ClientTasksCapability, - ClientTasksRequestsCapability, CompleteRequest, CompleteRequestParams, CompleteResult, @@ -46,7 +33,6 @@ CreateMessageRequestParams, CreateMessageResult, CreateMessageResultWithTools, - CreateTaskResult, ElicitationCapability, ElicitationRequiredErrorData, ElicitCompleteNotification, @@ -63,12 +49,6 @@ GetPromptRequest, GetPromptRequestParams, GetPromptResult, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequest, - GetTaskRequestParams, - GetTaskResult, Icon, IconTheme, ImageContent, @@ -86,8 +66,6 @@ ListResourceTemplatesResult, ListRootsRequest, ListRootsResult, - ListTasksRequest, - ListTasksResult, ListToolsRequest, ListToolsResult, LoggingCapability, @@ -114,7 +92,6 @@ ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, - RelatedTaskMetadata, Request, RequestParams, RequestParamsMeta, @@ -142,33 +119,16 @@ ServerNotification, ServerRequest, ServerResult, - ServerTasksCapability, - ServerTasksRequestsCapability, SetLevelRequest, SetLevelRequestParams, StopReason, SubscribeRequest, SubscribeRequestParams, - Task, - TaskExecutionMode, - TaskMetadata, - TasksCallCapability, - TasksCancelCapability, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksListCapability, - TasksSamplingCapability, - TaskStatus, - TaskStatusNotification, - TaskStatusNotificationParams, - TasksToolsCapability, TextContent, TextResourceContents, Tool, ToolAnnotations, ToolChoice, - ToolExecution, ToolListChangedNotification, ToolResultContent, ToolsCapability, @@ -208,16 +168,6 @@ # Protocol version constants "LATEST_PROTOCOL_VERSION", "DEFAULT_NEGOTIATED_VERSION", - # Task execution mode constants - "TASK_FORBIDDEN", - "TASK_OPTIONAL", - "TASK_REQUIRED", - # Task status constants - "TASK_STATUS_CANCELLED", - "TASK_STATUS_COMPLETED", - "TASK_STATUS_FAILED", - "TASK_STATUS_INPUT_REQUIRED", - "TASK_STATUS_WORKING", # Type aliases and variables "ContentBlock", "ElicitRequestedSchema", @@ -229,8 +179,6 @@ "SamplingContent", "SamplingMessageContentBlock", "StopReason", - "TaskExecutionMode", - "TaskStatus", # Base classes "BaseMetadata", "Request", @@ -245,8 +193,6 @@ "EmptyResult", # Capabilities "ClientCapabilities", - "ClientTasksCapability", - "ClientTasksRequestsCapability", "CompletionsCapability", "ElicitationCapability", "FormElicitationCapability", @@ -258,16 +204,6 @@ "SamplingContextCapability", "SamplingToolsCapability", "ServerCapabilities", - "ServerTasksCapability", - "ServerTasksRequestsCapability", - "TasksCancelCapability", - "TasksCallCapability", - "TasksCreateElicitationCapability", - "TasksCreateMessageCapability", - "TasksElicitationCapability", - "TasksListCapability", - "TasksSamplingCapability", - "TasksToolsCapability", "ToolsCapability", "UrlElicitationCapability", # Content types @@ -300,18 +236,12 @@ "ResourceTemplateReference", "Root", "SamplingMessage", - "Task", - "TaskMetadata", - "RelatedTaskMetadata", "Tool", "ToolAnnotations", "ToolChoice", - "ToolExecution", # Requests "CallToolRequest", "CallToolRequestParams", - "CancelTaskRequest", - "CancelTaskRequestParams", "CompleteRequest", "CompleteRequestParams", "CreateMessageRequest", @@ -321,17 +251,12 @@ "ElicitRequestURLParams", "GetPromptRequest", "GetPromptRequestParams", - "GetTaskPayloadRequest", - "GetTaskPayloadRequestParams", - "GetTaskRequest", - "GetTaskRequestParams", "InitializeRequest", "InitializeRequestParams", "ListPromptsRequest", "ListResourcesRequest", "ListResourceTemplatesRequest", "ListRootsRequest", - "ListTasksRequest", "ListToolsRequest", "PingRequest", "ReadResourceRequest", @@ -344,22 +269,17 @@ "UnsubscribeRequestParams", # Results "CallToolResult", - "CancelTaskResult", "CompleteResult", "CreateMessageResult", "CreateMessageResultWithTools", - "CreateTaskResult", "ElicitResult", "ElicitationRequiredErrorData", "GetPromptResult", - "GetTaskPayloadResult", - "GetTaskResult", "InitializeResult", "ListPromptsResult", "ListResourcesResult", "ListResourceTemplatesResult", "ListRootsResult", - "ListTasksResult", "ListToolsResult", "ReadResourceResult", # Notifications @@ -377,8 +297,6 @@ "ResourceUpdatedNotification", "ResourceUpdatedNotificationParams", "RootsListChangedNotification", - "TaskStatusNotification", - "TaskStatusNotificationParams", "ToolListChangedNotification", # Union types for request/response routing "ClientNotification", diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 9005d253af..34800ba12e 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -1,7 +1,6 @@ from __future__ import annotations -from datetime import datetime -from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar +from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, TypeAdapter from pydantic.alias_generators import to_camel @@ -30,11 +29,6 @@ IconTheme = Literal["light", "dark"] -TaskExecutionMode = Literal["forbidden", "optional", "required"] -TASK_FORBIDDEN: Final[Literal["forbidden"]] = "forbidden" -TASK_OPTIONAL: Final[Literal["optional"]] = "optional" -TASK_REQUIRED: Final[Literal["required"]] = "required" - class MCPModel(BaseModel): """Base class for all MCP protocol types.""" @@ -55,27 +49,7 @@ class RequestParamsMeta(TypedDict, extra_items=Any): """ -class TaskMetadata(MCPModel): - """Metadata for augmenting a request with task execution. - - Include this in the `task` field of the request parameters. - """ - - ttl: Annotated[int, Field(strict=True)] | None = None - """Requested duration in milliseconds to retain task from creation.""" - - class RequestParams(MCPModel): - task: TaskMetadata | None = None - """ - If specified, the caller is requesting task-augmented execution for this request. - The request will return a CreateTaskResult immediately, and the actual result can be - retrieved later via tasks/result. - - Task augmentation is subject to capability negotiation - receivers MUST declare support - for task augmentation of specific request types in their capabilities. - """ - meta: RequestParamsMeta | None = Field(alias="_meta", default=None) @@ -258,55 +232,6 @@ class SamplingCapability(MCPModel): """ -class TasksListCapability(MCPModel): - """Capability for tasks listing operations.""" - - -class TasksCancelCapability(MCPModel): - """Capability for tasks cancel operations.""" - - -class TasksCreateMessageCapability(MCPModel): - """Capability for tasks create messages.""" - - -class TasksSamplingCapability(MCPModel): - """Capability for tasks sampling operations.""" - - create_message: TasksCreateMessageCapability | None = None - - -class TasksCreateElicitationCapability(MCPModel): - """Capability for tasks create elicitation operations.""" - - -class TasksElicitationCapability(MCPModel): - """Capability for tasks elicitation operations.""" - - create: TasksCreateElicitationCapability | None = None - - -class ClientTasksRequestsCapability(MCPModel): - """Capability for tasks requests operations.""" - - sampling: TasksSamplingCapability | None = None - - elicitation: TasksElicitationCapability | None = None - - -class ClientTasksCapability(MCPModel): - """Capability for client tasks operations.""" - - list: TasksListCapability | None = None - """Whether this client supports tasks/list.""" - - cancel: TasksCancelCapability | None = None - """Whether this client supports tasks/cancel.""" - - requests: ClientTasksRequestsCapability | None = None - """Specifies which request types can be augmented with tasks.""" - - class ClientCapabilities(MCPModel): """Capabilities a client may support.""" @@ -321,8 +246,6 @@ class ClientCapabilities(MCPModel): """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" - tasks: ClientTasksCapability | None = None - """Present if the client supports task-augmented requests.""" class PromptsCapability(MCPModel): @@ -356,30 +279,6 @@ class CompletionsCapability(MCPModel): """Capability for completions operations.""" -class TasksCallCapability(MCPModel): - """Capability for tasks call operations.""" - - -class TasksToolsCapability(MCPModel): - """Capability for tasks tools operations.""" - - call: TasksCallCapability | None = None - - -class ServerTasksRequestsCapability(MCPModel): - """Capability for tasks requests operations.""" - - tools: TasksToolsCapability | None = None - - -class ServerTasksCapability(MCPModel): - """Capability for server tasks operations.""" - - list: TasksListCapability | None = None - cancel: TasksCancelCapability | None = None - requests: ServerTasksRequestsCapability | None = None - - class ServerCapabilities(MCPModel): """Capabilities that a server may support.""" @@ -401,146 +300,6 @@ class ServerCapabilities(MCPModel): completions: CompletionsCapability | None = None """Present if the server offers autocompletion suggestions for prompts and resources.""" - tasks: ServerTasksCapability | None = None - """Present if the server supports task-augmented requests.""" - - -TaskStatus = Literal["working", "input_required", "completed", "failed", "cancelled"] - -# Task status constants -TASK_STATUS_WORKING: Final[Literal["working"]] = "working" -TASK_STATUS_INPUT_REQUIRED: Final[Literal["input_required"]] = "input_required" -TASK_STATUS_COMPLETED: Final[Literal["completed"]] = "completed" -TASK_STATUS_FAILED: Final[Literal["failed"]] = "failed" -TASK_STATUS_CANCELLED: Final[Literal["cancelled"]] = "cancelled" - - -class RelatedTaskMetadata(MCPModel): - """Metadata for associating messages with a task. - - Include this in the `_meta` field under the key `io.modelcontextprotocol/related-task`. - """ - - task_id: str - """The task identifier this message is associated with.""" - - -class Task(MCPModel): - """Data associated with a task.""" - - task_id: str - """The task identifier.""" - - status: TaskStatus - """Current task state.""" - - status_message: str | None = None - """Optional human-readable message describing the current task state. - - This can provide context for any status, including: - - Reasons for "cancelled" status - - Summaries for "completed" status - - Diagnostic information for "failed" status (e.g., error details, what went wrong) - """ - - created_at: datetime # Pydantic will enforce ISO 8601 and re-serialize as a string later - """ISO 8601 timestamp when the task was created.""" - - last_updated_at: datetime - """ISO 8601 timestamp when the task was last updated.""" - - ttl: Annotated[int, Field(strict=True)] | None - """Actual retention duration from creation in milliseconds, null for unlimited.""" - - poll_interval: Annotated[int, Field(strict=True)] | None = None - """Suggested polling interval in milliseconds.""" - - -class CreateTaskResult(Result): - """A response to a task-augmented request.""" - - task: Task - - -class GetTaskRequestParams(RequestParams): - task_id: str - """The task identifier to query.""" - - -class GetTaskRequest(Request[GetTaskRequestParams, Literal["tasks/get"]]): - """A request to retrieve the state of a task.""" - - method: Literal["tasks/get"] = "tasks/get" - - params: GetTaskRequestParams - - -class GetTaskResult(Result, Task): - """The response to a tasks/get request.""" - - -class GetTaskPayloadRequestParams(RequestParams): - task_id: str - """The task identifier to retrieve results for.""" - - -class GetTaskPayloadRequest(Request[GetTaskPayloadRequestParams, Literal["tasks/result"]]): - """A request to retrieve the result of a completed task.""" - - method: Literal["tasks/result"] = "tasks/result" - params: GetTaskPayloadRequestParams - - -class GetTaskPayloadResult(Result): - """The response to a tasks/result request. - - The structure matches the result type of the original request. - For example, a tools/call task would return the CallToolResult structure. - """ - - model_config = ConfigDict(extra="allow", alias_generator=to_camel, populate_by_name=True) - - -class CancelTaskRequestParams(RequestParams): - task_id: str - """The task identifier to cancel.""" - - -class CancelTaskRequest(Request[CancelTaskRequestParams, Literal["tasks/cancel"]]): - """A request to cancel a task.""" - - method: Literal["tasks/cancel"] = "tasks/cancel" - params: CancelTaskRequestParams - - -class CancelTaskResult(Result, Task): - """The response to a tasks/cancel request.""" - - -class ListTasksRequest(PaginatedRequest[Literal["tasks/list"]]): - """A request to retrieve a list of tasks.""" - - method: Literal["tasks/list"] = "tasks/list" - - -class ListTasksResult(PaginatedResult): - """The response to a tasks/list request.""" - - tasks: list[Task] - - -class TaskStatusNotificationParams(NotificationParams, Task): - """Parameters for a `notifications/tasks/status` notification.""" - - -class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): - """An optional notification from the receiver to the requestor, informing them that a task's status has changed. - Receivers are not required to send these notifications. - """ - - method: Literal["notifications/tasks/status"] = "notifications/tasks/status" - params: TaskStatusNotificationParams - class InitializeRequestParams(RequestParams): """Parameters for the initialize request.""" @@ -1133,23 +892,6 @@ class ToolAnnotations(MCPModel): """ -class ToolExecution(MCPModel): - """Execution-related properties for a tool.""" - - task_support: TaskExecutionMode | None = None - """ - Indicates whether this tool supports task-augmented execution. - This allows clients to handle long-running operations through polling - the task system. - - - "forbidden": Tool does not support task-augmented execution (default when absent) - - "optional": Tool may support task-augmented execution - - "required": Tool requires task-augmented execution - - Default: "forbidden" - """ - - class Tool(BaseMetadata): """Definition for a tool the client can call.""" @@ -1172,8 +914,6 @@ class Tool(BaseMetadata): for notes on _meta usage. """ - execution: ToolExecution | None = None - class ListToolsResult(PaginatedResult): """The server's response to a tools/list request from the client.""" @@ -1554,8 +1294,6 @@ class CancelledNotificationParams(NotificationParams): The ID of the request to cancel. This MUST correspond to the ID of a request previously issued in the same direction. - This MUST be provided for cancelling non-task requests. - This MUST NOT be used for cancelling tasks (use the `tasks/cancel` request instead). """ reason: str | None = None """An optional string describing the reason for the cancellation.""" @@ -1607,20 +1345,12 @@ class ElicitCompleteNotification( | UnsubscribeRequest | CallToolRequest | ListToolsRequest - | GetTaskRequest - | GetTaskPayloadRequest - | ListTasksRequest - | CancelTaskRequest ) client_request_adapter = TypeAdapter[ClientRequest](ClientRequest) ClientNotification = ( - CancelledNotification - | ProgressNotification - | InitializedNotification - | RootsListChangedNotification - | TaskStatusNotification + CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification ) client_notification_adapter = TypeAdapter[ClientNotification](ClientNotification) @@ -1716,31 +1446,11 @@ class ElicitationRequiredErrorData(MCPModel): """List of URL mode elicitations that must be completed.""" -ClientResult = ( - EmptyResult - | CreateMessageResult - | CreateMessageResultWithTools - | ListRootsResult - | ElicitResult - | GetTaskResult - | GetTaskPayloadResult - | ListTasksResult - | CancelTaskResult - | CreateTaskResult -) +ClientResult = EmptyResult | CreateMessageResult | CreateMessageResultWithTools | ListRootsResult | ElicitResult client_result_adapter = TypeAdapter[ClientResult](ClientResult) -ServerRequest = ( - PingRequest - | CreateMessageRequest - | ListRootsRequest - | ElicitRequest - | GetTaskRequest - | GetTaskPayloadRequest - | ListTasksRequest - | CancelTaskRequest -) +ServerRequest = PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest server_request_adapter = TypeAdapter[ServerRequest](ServerRequest) @@ -1753,7 +1463,6 @@ class ElicitationRequiredErrorData(MCPModel): | ToolListChangedNotification | PromptListChangedNotification | ElicitCompleteNotification - | TaskStatusNotification ) server_notification_adapter = TypeAdapter[ServerNotification](ServerNotification) @@ -1769,10 +1478,5 @@ class ElicitationRequiredErrorData(MCPModel): | ReadResourceResult | CallToolResult | ListToolsResult - | GetTaskResult - | GetTaskPayloadResult - | ListTasksResult - | CancelTaskResult - | CreateTaskResult ) server_result_adapter = TypeAdapter[ServerResult](ServerResult) diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/__init__.py b/tests/experimental/tasks/__init__.py deleted file mode 100644 index 6e8649d283..0000000000 --- a/tests/experimental/tasks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for MCP task support.""" diff --git a/tests/experimental/tasks/client/__init__.py b/tests/experimental/tasks/client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py deleted file mode 100644 index 1ea2199e8c..0000000000 --- a/tests/experimental/tasks/client/test_capabilities.py +++ /dev/null @@ -1,312 +0,0 @@ -"""Tests for client task capabilities declaration during initialization.""" - -import anyio -import pytest - -from mcp import ClientCapabilities, types -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext -from mcp.shared.message import SessionMessage -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - Implementation, - InitializeRequest, - InitializeResult, - JSONRPCRequest, - JSONRPCResponse, - ServerCapabilities, - client_request_adapter, -) - - -@pytest.mark.anyio -async def test_client_capabilities_without_tasks(): - """Test that tasks capability is None when not provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities = None - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability is None when not provided - assert received_capabilities is not None - assert received_capabilities.tasks is None - - -@pytest.mark.anyio -async def test_client_capabilities_with_tasks(): - """Test that tasks capability is properly set when handlers are provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities: ClientCapabilities | None = None - - # Define custom handlers to trigger capability building (never actually called) - async def my_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: - raise NotImplementedError - - async def my_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - # Create handlers container - task_handlers = ExperimentalTaskHandlers( - list_tasks=my_list_tasks_handler, - cancel_task=my_cancel_task_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability is properly set from handlers - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert isinstance(received_capabilities.tasks, types.ClientTasksCapability) - assert received_capabilities.tasks.list is not None - assert received_capabilities.tasks.cancel is not None - - -@pytest.mark.anyio -async def test_client_capabilities_auto_built_from_handlers(): - """Test that tasks capability is automatically built from provided handlers.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities: ClientCapabilities | None = None - - # Define custom handlers (not defaults) - async def my_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: - raise NotImplementedError - - async def my_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - # Provide handlers via ExperimentalTaskHandlers - task_handlers = ExperimentalTaskHandlers( - list_tasks=my_list_tasks_handler, - cancel_task=my_cancel_task_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability was auto-built from handlers - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert received_capabilities.tasks.list is not None - assert received_capabilities.tasks.cancel is not None - # requests should be None since we didn't provide task-augmented handlers - assert received_capabilities.tasks.requests is None - - -@pytest.mark.anyio -async def test_client_capabilities_with_task_augmented_handlers(): - """Test that requests capability is built when augmented handlers are provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities: ClientCapabilities | None = None - - # Define task-augmented handler - async def my_augmented_sampling_handler( - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - # Provide task-augmented sampling handler - task_handlers = ExperimentalTaskHandlers( - augmented_sampling=my_augmented_sampling_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability includes requests.sampling - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert received_capabilities.tasks.requests is not None - assert received_capabilities.tasks.requests.sampling is not None - assert received_capabilities.tasks.requests.elicitation is None # Not provided diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py deleted file mode 100644 index 137ff80106..0000000000 --- a/tests/experimental/tasks/client/test_handlers.py +++ /dev/null @@ -1,874 +0,0 @@ -"""Tests for client-side task management handlers (server -> client requests). - -These tests verify that clients can handle task-related requests from servers: -- GetTaskRequest - server polling client's task status -- GetTaskPayloadRequest - server getting result from client's task -- ListTasksRequest - server listing client's tasks -- CancelTaskRequest - server cancelling client's task - -This is the inverse of the existing tests in test_tasks.py, which test -client -> server task requests. -""" - -from collections.abc import AsyncIterator -from dataclasses import dataclass - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - -from mcp import types -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ( - CancelTaskRequest, - CancelTaskRequestParams, - CancelTaskResult, - ClientResult, - CreateMessageRequest, - CreateMessageRequestParams, - CreateMessageResult, - CreateTaskResult, - ElicitRequest, - ElicitRequestFormParams, - ElicitRequestParams, - ElicitResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequest, - GetTaskRequestParams, - GetTaskResult, - ListTasksRequest, - ListTasksResult, - SamplingMessage, - ServerNotification, - ServerRequest, - TaskMetadata, - TextContent, -) - -# Buffer size for test streams -STREAM_BUFFER_SIZE = 10 - - -@dataclass -class ClientTestStreams: - """Bidirectional message streams for client/server communication in tests.""" - - server_send: MemoryObjectSendStream[SessionMessage] - server_receive: MemoryObjectReceiveStream[SessionMessage] - client_send: MemoryObjectSendStream[SessionMessage] - client_receive: MemoryObjectReceiveStream[SessionMessage] - - -@pytest.fixture -async def client_streams() -> AsyncIterator[ClientTestStreams]: - """Create bidirectional message streams for client tests. - - Automatically closes all streams after the test completes. - """ - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage]( - STREAM_BUFFER_SIZE - ) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage]( - STREAM_BUFFER_SIZE - ) - - streams = ClientTestStreams( - server_send=server_to_client_send, - server_receive=client_to_server_receive, - client_send=client_to_server_send, - client_receive=server_to_client_receive, - ) - - yield streams - - # Cleanup - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -async def _default_message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, -) -> None: - """Default message handler that ignores messages (tests handle them explicitly).""" - ... - - -@pytest.mark.anyio -async def test_client_handles_get_task_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to GetTaskRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - received_task_id: str | None = None - - async def get_task_handler( - context: RequestContext[ClientSession], - params: GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - nonlocal received_task_id - received_task_id = params.task_id - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-123") - - task_handlers = ExperimentalTaskHandlers(get_task=get_task_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskRequest(params=GetTaskRequestParams(task_id="test-task-123")) - request = types.JSONRPCRequest(jsonrpc="2.0", id="req-1", **typed_request.model_dump(by_alias=True)) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - assert response.id == "req-1" - - result = GetTaskResult.model_validate(response.result) - assert result.task_id == "test-task-123" - assert result.status == "working" - assert received_task_id == "test-task-123" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_handles_get_task_result_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to GetTaskPayloadRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, types.CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-456") - await store.store_result( - "test-task-456", - types.CallToolResult(content=[TextContent(type="text", text="Task completed successfully!")]), - ) - await store.update_task("test-task-456", status="completed") - - task_handlers = ExperimentalTaskHandlers(get_task_result=get_task_result_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-456")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-2", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - assert isinstance(response.result, dict) - result_dict = response.result - assert "content" in result_dict - assert len(result_dict["content"]) == 1 - assert result_dict["content"][0]["text"] == "Task completed successfully!" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_handles_list_tasks_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to ListTasksRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> ListTasksResult | ErrorData: - cursor = params.cursor if params else None - tasks_list, next_cursor = await store.list_tasks(cursor=cursor) - return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - - await store.create_task(TaskMetadata(ttl=60000), task_id="task-1") - await store.create_task(TaskMetadata(ttl=60000), task_id="task-2") - - task_handlers = ExperimentalTaskHandlers(list_tasks=list_tasks_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = ListTasksRequest() - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-3", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - result = ListTasksResult.model_validate(response.result) - assert len(result.tasks) == 2 - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_handles_cancel_task_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to CancelTaskRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def cancel_task_handler( - context: RequestContext[ClientSession], - params: CancelTaskRequestParams, - ) -> CancelTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - await store.update_task(params.task_id, status="cancelled") - updated = await store.get_task(params.task_id) - assert updated is not None - return CancelTaskResult( - task_id=updated.task_id, - status=updated.status, - created_at=updated.created_at, - last_updated_at=updated.last_updated_at, - ttl=updated.ttl, - ) - - await store.create_task(TaskMetadata(ttl=60000), task_id="task-to-cancel") - - task_handlers = ExperimentalTaskHandlers(cancel_task=cancel_task_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = CancelTaskRequest(params=CancelTaskRequestParams(task_id="task-to-cancel")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-4", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - result = CancelTaskResult.model_validate(response.result) - assert result.task_id == "task-to-cancel" - assert result.status == "cancelled" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_task_augmented_sampling(client_streams: ClientTestStreams) -> None: - """Test that client can handle task-augmented sampling request from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - sampling_completed = Event() - created_task_id: list[str | None] = [None] - background_tg: list[TaskGroup | None] = [None] - - async def task_augmented_sampling_callback( - context: RequestContext[ClientSession], - params: CreateMessageRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - task = await store.create_task(task_metadata) - created_task_id[0] = task.task_id - - async def do_sampling() -> None: - result = CreateMessageResult( - role="assistant", - content=TextContent(type="text", text="Sampled response"), - model="test-model", - stop_reason="endTurn", - ) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - sampling_completed.set() - - assert background_tg[0] is not None - background_tg[0].start_soon(do_sampling) - return CreateTaskResult(task=task) - - async def get_task_handler( - context: RequestContext[ClientSession], - params: GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, CreateMessageResult) - return GetTaskPayloadResult(**result.model_dump()) - - task_handlers = ExperimentalTaskHandlers( - augmented_sampling=task_augmented_sampling_callback, - get_task=get_task_handler, - get_task_result=get_task_result_handler, - ) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - background_tg[0] = tg - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Step 1: Server sends task-augmented CreateMessageRequest - typed_request = CreateMessageRequest( - params=CreateMessageRequestParams( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-sampling", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - # Step 2: Client responds with CreateTaskResult - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - task_result = CreateTaskResult.model_validate(response.result) - task_id = task_result.task.task_id - assert task_id == created_task_id[0] - - # Step 3: Wait for background sampling - await sampling_completed.wait() - - # Step 4: Server polls task status - typed_poll = GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)) - poll_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-poll", - **typed_poll.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(poll_request)) - - poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message - assert isinstance(poll_response, types.JSONRPCResponse) - - status = GetTaskResult.model_validate(poll_response.result) - assert status.status == "completed" - - # Step 5: Server gets result - typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)) - result_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_result_req.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(result_request)) - - result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message - assert isinstance(result_response, types.JSONRPCResponse) - - assert isinstance(result_response.result, dict) - assert result_response.result["role"] == "assistant" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_task_augmented_elicitation(client_streams: ClientTestStreams) -> None: - """Test that client can handle task-augmented elicitation request from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - elicitation_completed = Event() - created_task_id: list[str | None] = [None] - background_tg: list[TaskGroup | None] = [None] - - async def task_augmented_elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult | ErrorData: - task = await store.create_task(task_metadata) - created_task_id[0] = task.task_id - - async def do_elicitation() -> None: - # Simulate user providing elicitation response - result = ElicitResult(action="accept", content={"name": "Test User"}) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - elicitation_completed.set() - - assert background_tg[0] is not None - background_tg[0].start_soon(do_elicitation) - return CreateTaskResult(task=task) - - async def get_task_handler( - context: RequestContext[ClientSession], - params: GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, ElicitResult) - return GetTaskPayloadResult(**result.model_dump()) - - task_handlers = ExperimentalTaskHandlers( - augmented_elicitation=task_augmented_elicitation_callback, - get_task=get_task_handler, - get_task_result=get_task_result_handler, - ) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - background_tg[0] = tg - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Step 1: Server sends task-augmented ElicitRequest - typed_request = ElicitRequest( - params=ElicitRequestFormParams( - message="What is your name?", - requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-elicit", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - # Step 2: Client responds with CreateTaskResult - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - task_result = CreateTaskResult.model_validate(response.result) - task_id = task_result.task.task_id - assert task_id == created_task_id[0] - - # Step 3: Wait for background elicitation - await elicitation_completed.wait() - - # Step 4: Server polls task status - typed_poll = GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)) - poll_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-poll", - **typed_poll.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(poll_request)) - - poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message - assert isinstance(poll_response, types.JSONRPCResponse) - - status = GetTaskResult.model_validate(poll_response.result) - assert status.status == "completed" - - # Step 5: Server gets result - typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)) - result_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_result_req.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(result_request)) - - result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message - assert isinstance(result_response, types.JSONRPCResponse) - - # Verify the elicitation result - assert isinstance(result_response.result, dict) - assert result_response.result["action"] == "accept" - assert result_response.result["content"] == {"name": "Test User"} - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error when no handler is registered for task request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-unhandled", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert ( - "not supported" in response.error.message.lower() - or "method not found" in response.error.message.lower() - ) - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_result_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/result request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_list_tasks_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/list request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = ListTasksRequest() - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-list", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_cancel_task_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/cancel request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = CancelTaskRequest(params=CancelTaskRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-cancel", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_augmented_sampling(client_streams: ClientTestStreams) -> None: - """Test that client returns error for task-augmented sampling without handler.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - # No task handlers provided - uses defaults - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Send task-augmented sampling request - typed_request = CreateMessageRequest( - params=CreateMessageRequestParams( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-sampling", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_augmented_elicitation( - client_streams: ClientTestStreams, -) -> None: - """Test that client returns error for task-augmented elicitation without handler.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - # No task handlers provided - uses defaults - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Send task-augmented elicitation request - typed_request = ElicitRequest( - params=ElicitRequestFormParams( - message="What is your name?", - requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-elicit", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() diff --git a/tests/experimental/tasks/client/test_poll_task.py b/tests/experimental/tasks/client/test_poll_task.py deleted file mode 100644 index 5e3158d955..0000000000 --- a/tests/experimental/tasks/client/test_poll_task.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Tests for poll_task async iterator.""" - -from collections.abc import Callable, Coroutine -from datetime import datetime, timezone -from typing import Any -from unittest.mock import AsyncMock - -import pytest - -from mcp.client.experimental.tasks import ExperimentalClientFeatures -from mcp.types import GetTaskResult, TaskStatus - - -def make_task_result( - status: TaskStatus = "working", - poll_interval: int = 0, - task_id: str = "test-task", - status_message: str | None = None, -) -> GetTaskResult: - """Create GetTaskResult with sensible defaults.""" - now = datetime.now(timezone.utc) - return GetTaskResult( - task_id=task_id, - status=status, - status_message=status_message, - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=poll_interval, - ) - - -def make_status_sequence( - *statuses: TaskStatus, - task_id: str = "test-task", -) -> Callable[[str], Coroutine[Any, Any, GetTaskResult]]: - """Create mock get_task that returns statuses in sequence.""" - status_iter = iter(statuses) - - async def mock_get_task(tid: str) -> GetTaskResult: - return make_task_result(status=next(status_iter), task_id=tid) - - return mock_get_task - - -@pytest.fixture -def mock_session() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def features(mock_session: AsyncMock) -> ExperimentalClientFeatures: - return ExperimentalClientFeatures(mock_session) - - -@pytest.mark.anyio -async def test_poll_task_yields_until_completed(features: ExperimentalClientFeatures) -> None: - """poll_task yields each status until terminal.""" - features.get_task = make_status_sequence("working", "working", "completed") # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == ["working", "working", "completed"] - - -@pytest.mark.anyio -@pytest.mark.parametrize("terminal_status", ["completed", "failed", "cancelled"]) -async def test_poll_task_exits_on_terminal(features: ExperimentalClientFeatures, terminal_status: TaskStatus) -> None: - """poll_task exits immediately when task is already terminal.""" - features.get_task = make_status_sequence(terminal_status) # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == [terminal_status] - - -@pytest.mark.anyio -async def test_poll_task_continues_through_input_required(features: ExperimentalClientFeatures) -> None: - """poll_task yields input_required and continues (non-terminal).""" - features.get_task = make_status_sequence("working", "input_required", "working", "completed") # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == ["working", "input_required", "working", "completed"] - - -@pytest.mark.anyio -async def test_poll_task_passes_task_id(features: ExperimentalClientFeatures) -> None: - """poll_task passes correct task_id to get_task.""" - received_ids: list[str] = [] - - async def mock_get_task(task_id: str) -> GetTaskResult: - received_ids.append(task_id) - return make_task_result(status="completed", task_id=task_id) - - features.get_task = mock_get_task # type: ignore[method-assign] - - _ = [s async for s in features.poll_task("my-task-123")] - - assert received_ids == ["my-task-123"] - - -@pytest.mark.anyio -async def test_poll_task_yields_full_result(features: ExperimentalClientFeatures) -> None: - """poll_task yields complete GetTaskResult objects.""" - - async def mock_get_task(task_id: str) -> GetTaskResult: - return make_task_result( - status="completed", - task_id=task_id, - status_message="All done!", - ) - - features.get_task = mock_get_task # type: ignore[method-assign] - - results = [r async for r in features.poll_task("test-task")] - - assert len(results) == 1 - assert results[0].status == "completed" - assert results[0].status_message == "All done!" - assert results[0].task_id == "test-task" diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py deleted file mode 100644 index 613c794ebf..0000000000 --- a/tests/experimental/tasks/client/test_tasks.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Tests for the experimental client task methods (session.experimental).""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.shared.experimental.tasks.helpers import task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CancelTaskRequestParams, - CancelTaskResult, - CreateTaskResult, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - TaskMetadata, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -@dataclass -class AppContext: - """Application context passed via lifespan_context.""" - - task_group: TaskGroup - store: InMemoryTaskStore - task_done_events: dict[str, Event] = field(default_factory=lambda: {}) - - -async def _handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None -) -> ListToolsResult: - raise NotImplementedError - - -async def _handle_call_tool_with_done_event( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams, *, result_text: str = "Done" -) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - -def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): - @asynccontextmanager - async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: - async with anyio.create_task_group() as tg: - yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - - return app_lifespan - - -async def test_session_experimental_get_task() -> None: - """Test session.experimental.get_task() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool_with_done_event, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - # Create a task - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await task_done_events[task_id].wait() - - # Use session.experimental to get task status - task_status = await client.session.experimental.get_task(task_id) - - assert task_status.task_id == task_id - assert task_status.status == "completed" - - -async def test_session_experimental_get_task_result() -> None: - """Test session.experimental.get_task_result() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - return await _handle_call_tool_with_done_event(ctx, params, result_text="Task result content") - - async def handle_get_task_result( - ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - app = ctx.lifespan_context - result = await app.store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks(on_task_result=handle_get_task_result) - - async with Client(server) as client: - # Create a task - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await task_done_events[task_id].wait() - - # Use TaskClient to get task result - task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Task result content" - - -async def test_session_experimental_list_tasks() -> None: - """Test TaskClient.list_tasks() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tasks( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListTasksResult: - app = ctx.lifespan_context - cursor = params.cursor if params else None - tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor) - return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool_with_done_event, - ) - server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - - async with Client(server) as client: - # Create two tasks - for _ in range(2): - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - await task_done_events[create_result.task.task_id].wait() - - # Use TaskClient to list tasks - list_result = await client.session.experimental.list_tasks() - - assert len(list_result.tasks) == 2 - - -async def test_session_experimental_cancel_task() -> None: - """Test TaskClient.cancel_task() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_call_tool_no_work( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - # Don't start any work - task stays in "working" status - return CreateTaskResult(task=task) - raise NotImplementedError - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def handle_cancel_task( - ctx: ServerRequestContext[AppContext], params: CancelTaskRequestParams - ) -> CancelTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - await app.store.update_task(params.task_id, status="cancelled") - updated_task = await app.store.get_task(params.task_id) - assert updated_task is not None - return CancelTaskResult( - task_id=updated_task.task_id, - status=updated_task.status, - created_at=updated_task.created_at, - last_updated_at=updated_task.last_updated_at, - ttl=updated_task.ttl, - ) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=handle_call_tool_no_work, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task, on_cancel_task=handle_cancel_task) - - async with Client(server) as client: - # Create a task (but don't complete it) - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Verify task is working - status_before = await client.session.experimental.get_task(task_id) - assert status_before.status == "working" - - # Cancel the task - await client.session.experimental.cancel_task(task_id) - - # Verify task is cancelled - status_after = await client.session.experimental.get_task(task_id) - assert status_after.status == "cancelled" diff --git a/tests/experimental/tasks/server/__init__.py b/tests/experimental/tasks/server/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/server/test_context.py b/tests/experimental/tasks/server/test_context.py deleted file mode 100644 index a0f1a190d2..0000000000 --- a/tests/experimental/tasks/server/test_context.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Tests for TaskContext and helper functions.""" - -import pytest - -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.helpers import create_task_state, task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import CallToolResult, TaskMetadata, TextContent - - -@pytest.mark.anyio -async def test_task_context_properties() -> None: - """Test TaskContext basic properties.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - assert ctx.task_id == task.task_id - assert ctx.task.task_id == task.task_id - assert ctx.task.status == "working" - assert ctx.is_cancelled is False - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_update_status() -> None: - """Test TaskContext.update_status.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - await ctx.update_status("Processing step 1...") - - # Check status message was updated - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status_message == "Processing step 1..." - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_complete() -> None: - """Test TaskContext.complete.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - result = CallToolResult(content=[TextContent(type="text", text="Done!")]) - await ctx.complete(result) - - # Check task status - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status == "completed" - - # Check result is stored - stored_result = await store.get_result(task.task_id) - assert stored_result is not None - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_fail() -> None: - """Test TaskContext.fail.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - await ctx.fail("Something went wrong!") - - # Check task status - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status == "failed" - assert updated.status_message == "Something went wrong!" - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_cancellation() -> None: - """Test TaskContext cancellation request.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - assert ctx.is_cancelled is False - - ctx.request_cancellation() - - assert ctx.is_cancelled is True - - store.cleanup() - - -def test_create_task_state_generates_id() -> None: - """create_task_state generates a unique task ID when none provided.""" - task1 = create_task_state(TaskMetadata(ttl=60000)) - task2 = create_task_state(TaskMetadata(ttl=60000)) - - assert task1.task_id != task2.task_id - - -def test_create_task_state_uses_provided_id() -> None: - """create_task_state uses the provided task ID.""" - task = create_task_state(TaskMetadata(ttl=60000), task_id="my-task-123") - assert task.task_id == "my-task-123" - - -def test_create_task_state_null_ttl() -> None: - """create_task_state handles null TTL.""" - task = create_task_state(TaskMetadata(ttl=None)) - assert task.ttl is None - - -def test_create_task_state_has_created_at() -> None: - """create_task_state sets createdAt timestamp.""" - task = create_task_state(TaskMetadata(ttl=60000)) - assert task.created_at is not None - - -@pytest.mark.anyio -async def test_task_execution_provides_context() -> None: - """task_execution provides a TaskContext for the task.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-test-1") - - async with task_execution("exec-test-1", store) as ctx: - assert ctx.task_id == "exec-test-1" - assert ctx.task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_auto_fails_on_exception() -> None: - """task_execution automatically fails task on unhandled exception.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-fail-1") - - async with task_execution("exec-fail-1", store): - raise RuntimeError("Oops!") - - # Task should be failed - failed_task = await store.get_task("exec-fail-1") - assert failed_task is not None - assert failed_task.status == "failed" - assert "Oops!" in (failed_task.status_message or "") - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_doesnt_fail_if_already_terminal() -> None: - """task_execution doesn't re-fail if task already terminal.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-term-1") - - async with task_execution("exec-term-1", store) as ctx: - # Complete the task first - await ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - # Then raise - shouldn't change status - raise RuntimeError("This shouldn't matter") - - # Task should remain completed - final_task = await store.get_task("exec-term-1") - assert final_task is not None - assert final_task.status == "completed" - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_not_found() -> None: - """task_execution raises ValueError for non-existent task.""" - store = InMemoryTaskStore() - - with pytest.raises(ValueError, match="not found"): - async with task_execution("nonexistent", store): - ... diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py deleted file mode 100644 index b5b79033d0..0000000000 --- a/tests/experimental/tasks/server/test_integration.py +++ /dev/null @@ -1,247 +0,0 @@ -"""End-to-end integration tests for tasks functionality. - -These tests demonstrate the full task lifecycle: -1. Client sends task-augmented request (tools/call with task metadata) -2. Server creates task and returns CreateTaskResult immediately -3. Background work executes (using task_execution context manager) -4. Client polls with tasks/get -5. Client retrieves result with tasks/result -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.shared.experimental.tasks.helpers import task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CreateTaskResult, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - TaskMetadata, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -@dataclass -class AppContext: - """Application context passed via lifespan_context.""" - - task_group: TaskGroup - store: InMemoryTaskStore - task_done_events: dict[str, Event] = field(default_factory=lambda: {}) - - -def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): - @asynccontextmanager - async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: - async with anyio.create_task_group() as tg: - yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - - return app_lifespan - - -async def test_task_lifecycle_with_task_execution() -> None: - """Test the complete task lifecycle using the task_execution pattern.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if params.name == "process_data" and ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.update_status("Processing input...") - input_value = (params.arguments or {}).get("input", "") - result_text = f"Processed: {input_value.upper()}" - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def handle_get_task_result( - ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - app = ctx.lifespan_context - result = await app.store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - async def handle_list_tasks( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListTasksResult: - raise NotImplementedError - - server: Server[AppContext] = Server( - "test-tasks", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks( - on_get_task=handle_get_task, - on_task_result=handle_get_task_result, - on_list_tasks=handle_list_tasks, - ) - - async with Client(server) as client: - # Step 1: Send task-augmented tool call - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="process_data", - arguments={"input": "hello world"}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - assert isinstance(create_result, CreateTaskResult) - assert create_result.task.status == "working" - task_id = create_result.task.task_id - - # Step 2: Wait for task to complete - await task_done_events[task_id].wait() - - task_status = await client.session.experimental.get_task(task_id) - assert task_status.task_id == task_id - assert task_status.status == "completed" - - # Step 3: Retrieve the actual result - task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Processed: HELLO WORLD" - - -async def test_task_auto_fails_on_exception() -> None: - """Test that task_execution automatically fails the task on unhandled exception.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if params.name == "failing_task" and ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_failing_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.update_status("About to fail...") - raise RuntimeError("Something went wrong!") - # This line is reached because task_execution suppresses the exception - done_event.set() - - app.task_group.start_soon(do_failing_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - server: Server[AppContext] = Server( - "test-tasks-failure", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - # Send task request - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="failing_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - task_id = create_result.task.task_id - - # Wait for task to complete (even though it fails) - await task_done_events[task_id].wait() - - # Check that task was auto-failed - task_status = await client.session.experimental.get_task(task_id) - - assert task_status.status == "failed" - assert task_status.status_message == "Something went wrong!" diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py deleted file mode 100644 index 027382e69e..0000000000 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ /dev/null @@ -1,367 +0,0 @@ -"""Tests for the simplified task API: enable_tasks() + run_task() - -This tests the recommended user flow: -1. server.experimental.enable_tasks() - one-line setup -2. ctx.experimental.run_task(work) - spawns work, returns CreateTaskResult -3. work function uses ServerTaskContext for elicit/create_message - -These are integration tests that verify the complete flow works end-to-end. -""" - -from unittest.mock import Mock - -import anyio -import pytest -from anyio import Event - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.request_context import Experimental -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.server.lowlevel import NotificationOptions -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.types import ( - TASK_REQUIRED, - CallToolRequestParams, - CallToolResult, - CreateTaskResult, - GetTaskRequestParams, - GetTaskResult, - ListToolsResult, - PaginatedRequestParams, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -async def _handle_list_tools_simple_task( - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - raise NotImplementedError - - -async def test_run_task_basic_flow() -> None: - """Test the basic run_task flow without elicitation.""" - work_completed = Event() - received_meta: list[str | None] = [None] - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - if ctx.meta is not None: # pragma: no branch - received_meta[0] = ctx.meta.get("custom_field") - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Working...") - input_val = (params.arguments or {}).get("input", "default") - result = CallToolResult(content=[TextContent(type="text", text=f"Processed: {input_val}")]) - work_completed.set() - return result - - return await ctx.experimental.run_task(work) - - server = Server( - "test-run-task", - on_list_tools=_handle_list_tools_simple_task, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task( - "simple_task", - {"input": "hello"}, - meta={"custom_field": "test_value"}, - ) - - task_id = result.task.task_id - assert result.task.status == "working" - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - task_status = await client.session.experimental.get_task(task_id) - if task_status.status == "completed": # pragma: no branch - break - - assert received_meta[0] == "test_value" - - -async def test_run_task_auto_fails_on_exception() -> None: - """Test that run_task automatically fails the task when work raises.""" - work_failed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - work_failed.set() - raise RuntimeError("Something went wrong!") - - return await ctx.experimental.run_task(work) - - server = Server( - "test-run-task-fail", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("failing_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_failed.wait() - - with anyio.fail_after(5): - while True: - task_status = await client.session.experimental.get_task(task_id) - if task_status.status == "failed": # pragma: no branch - break - - assert "Something went wrong" in (task_status.status_message or "") - - -async def test_enable_tasks_auto_registers_handlers() -> None: - """Test that enable_tasks() auto-registers get_task, list_tasks, cancel_task handlers.""" - server = Server("test-enable-tasks") - - # Before enable_tasks, no task capabilities - caps_before = server.get_capabilities(NotificationOptions(), {}) - assert caps_before.tasks is None - - # Enable tasks - server.experimental.enable_tasks() - - # After enable_tasks, should have task capabilities - caps_after = server.get_capabilities(NotificationOptions(), {}) - assert caps_after.tasks is not None - assert caps_after.tasks.list is not None - assert caps_after.tasks.cancel is not None - assert caps_after.tasks.requests is not None - assert caps_after.tasks.requests.tools is not None - assert caps_after.tasks.requests.tools.call is not None - - -async def test_enable_tasks_with_custom_store_and_queue() -> None: - """Test that enable_tasks() uses provided store and queue instead of defaults.""" - server = Server("test-custom-store-queue") - - custom_store = InMemoryTaskStore() - custom_queue = InMemoryTaskMessageQueue() - - task_support = server.experimental.enable_tasks(store=custom_store, queue=custom_queue) - - assert task_support.store is custom_store - assert task_support.queue is custom_queue - - -async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> None: - """Test that enable_tasks() doesn't override already-registered handlers.""" - server = Server("test-custom-handlers") - - # Register custom handlers via enable_tasks kwargs - async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - server.experimental.enable_tasks(on_get_task=custom_get_task) - - # Verify handler is registered - assert server._has_handler("tasks/get") - assert server._has_handler("tasks/list") - assert server._has_handler("tasks/cancel") - assert server._has_handler("tasks/result") - - -async def test_run_task_without_enable_tasks_raises() -> None: - """Test that run_task raises when enable_tasks() wasn't called.""" - experimental = Experimental( - task_metadata=None, - _client_capabilities=None, - _session=None, - _task_support=None, # Not enabled - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Task support not enabled"): - await experimental.run_task(work) - - -async def test_task_support_task_group_before_run_raises() -> None: - """Test that accessing task_group before run() raises RuntimeError.""" - task_support = TaskSupport.in_memory() - - with pytest.raises(RuntimeError, match="TaskSupport not running"): - _ = task_support.task_group - - -async def test_run_task_without_session_raises() -> None: - """Test that run_task raises when session is not available.""" - task_support = TaskSupport.in_memory() - - experimental = Experimental( - task_metadata=None, - _client_capabilities=None, - _session=None, # No session - _task_support=task_support, - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Session not available"): - await experimental.run_task(work) - - -async def test_run_task_without_task_metadata_raises() -> None: - """Test that run_task raises when request is not task-augmented.""" - task_support = TaskSupport.in_memory() - mock_session = Mock() - - experimental = Experimental( - task_metadata=None, # Not a task-augmented request - _client_capabilities=None, - _session=mock_session, - _task_support=task_support, - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Request is not task-augmented"): - await experimental.run_task(work) - - -async def test_run_task_with_model_immediate_response() -> None: - """Test that run_task includes model_immediate_response in CreateTaskResult._meta.""" - work_completed = Event() - immediate_response_text = "Processing your request..." - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="Done")]) - - return await ctx.experimental.run_task(work, model_immediate_response=immediate_response_text) - - server = Server( - "test-run-task-immediate", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("task_with_immediate", {}) - - assert result.meta is not None - assert "io.modelcontextprotocol/model-immediate-response" in result.meta - assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text - - with anyio.fail_after(5): - await work_completed.wait() - - -async def test_run_task_doesnt_complete_if_already_terminal() -> None: - """Test that run_task doesn't auto-complete if work manually completed the task.""" - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - manual_result = CallToolResult(content=[TextContent(type="text", text="Manually completed")]) - await task.complete(manual_result, notify=False) - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="This should be ignored")]) - - return await ctx.experimental.run_task(work) - - server = Server( - "test-already-complete", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("manual_complete_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - status = await client.session.experimental.get_task(task_id) - if status.status == "completed": # pragma: no branch - break - - -async def test_run_task_doesnt_fail_if_already_terminal() -> None: - """Test that run_task doesn't auto-fail if work manually failed/cancelled the task.""" - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.fail("Manually failed", notify=False) - work_completed.set() - raise RuntimeError("This error should not change status") - - return await ctx.experimental.run_task(work) - - server = Server( - "test-already-failed", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("manual_cancel_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - status = await client.session.experimental.get_task(task_id) - if status.status == "failed": # pragma: no branch - break - - assert status.status_message == "Manually failed" diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py deleted file mode 100644 index 6a28b274ea..0000000000 --- a/tests/experimental/tasks/server/test_server.py +++ /dev/null @@ -1,797 +0,0 @@ -"""Tests for server-side task support (handlers, capabilities, integration).""" - -from datetime import datetime, timezone -from typing import Any - -import anyio -import pytest - -from mcp import Client -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.response_router import ResponseRouter -from mcp.shared.session import RequestResponder -from mcp.types import ( - INVALID_REQUEST, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CancelTaskRequestParams, - CancelTaskResult, - ClientResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - JSONRPCError, - JSONRPCNotification, - JSONRPCResponse, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - SamplingMessage, - ServerCapabilities, - ServerNotification, - ServerRequest, - Task, - TaskMetadata, - TextContent, - Tool, - ToolExecution, -) - -pytestmark = pytest.mark.anyio - - -async def test_list_tasks_handler() -> None: - """Test that experimental list_tasks handler works via Client.""" - now = datetime.now(timezone.utc) - test_tasks = [ - Task(task_id="task-1", status="working", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), - Task(task_id="task-2", status="completed", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), - ] - - async def handle_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - return ListTasksResult(tasks=test_tasks) - - server = Server("test") - server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - - async with Client(server) as client: - result = await client.session.experimental.list_tasks() - assert len(result.tasks) == 2 - assert result.tasks[0].task_id == "task-1" - assert result.tasks[1].task_id == "task-2" - - -async def test_get_task_handler() -> None: - """Test that experimental get_task handler works via Client.""" - - async def handle_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - now = datetime.now(timezone.utc) - return GetTaskResult( - task_id=params.task_id, - status="working", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ) - - server = Server("test") - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - result = await client.session.experimental.get_task("test-task-123") - assert result.task_id == "test-task-123" - assert result.status == "working" - - -async def test_get_task_result_handler() -> None: - """Test that experimental get_task_result handler works via Client.""" - - async def handle_get_task_result( - ctx: ServerRequestContext, params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - return GetTaskPayloadResult() - - server = Server("test") - server.experimental.enable_tasks(on_task_result=handle_get_task_result) - - async with Client(server) as client: - result = await client.session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-123")), - GetTaskPayloadResult, - ) - assert isinstance(result, GetTaskPayloadResult) - - -async def test_cancel_task_handler() -> None: - """Test that experimental cancel_task handler works via Client.""" - - async def handle_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - now = datetime.now(timezone.utc) - return CancelTaskResult( - task_id=params.task_id, - status="cancelled", - created_at=now, - last_updated_at=now, - ttl=60000, - ) - - server = Server("test") - server.experimental.enable_tasks(on_cancel_task=handle_cancel_task) - - async with Client(server) as client: - result = await client.session.experimental.cancel_task("test-task-123") - assert result.task_id == "test-task-123" - assert result.status == "cancelled" - - -async def test_server_capabilities_include_tasks() -> None: - """Test that server capabilities include tasks when handlers are registered.""" - server = Server("test") - - async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - async def noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - raise NotImplementedError - - server.experimental.enable_tasks(on_list_tasks=noop_list_tasks, on_cancel_task=noop_cancel_task) - - capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) - - assert capabilities.tasks is not None - assert capabilities.tasks.list is not None - assert capabilities.tasks.cancel is not None - assert capabilities.tasks.requests is not None - assert capabilities.tasks.requests.tools is not None - - -@pytest.mark.skip( - reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. Low-level API should support " - "selectively enabling/disabling task capabilities." -) -async def test_server_capabilities_partial_tasks() -> None: # pragma: no cover - """Test capabilities with only some task handlers registered.""" - server = Server("test") - - async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - # Only list_tasks registered, not cancel_task - server.experimental.enable_tasks(on_list_tasks=noop_list_tasks) - - capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) - - assert capabilities.tasks is not None - assert capabilities.tasks.list is not None - assert capabilities.tasks.cancel is None # Not registered - - -async def test_tool_with_task_execution_metadata() -> None: - """Test that tools can declare task execution mode.""" - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="quick_tool", - description="Fast tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_FORBIDDEN), - ), - Tool( - name="long_tool", - description="Long running tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ), - Tool( - name="flexible_tool", - description="Can be either", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ), - ] - ) - - server = Server("test", on_list_tools=handle_list_tools) - - async with Client(server) as client: - result = await client.list_tools() - tools = result.tools - - assert tools[0].execution is not None - assert tools[0].execution.task_support == TASK_FORBIDDEN - assert tools[1].execution is not None - assert tools[1].execution.task_support == TASK_REQUIRED - assert tools[2].execution is not None - assert tools[2].execution.task_support == TASK_OPTIONAL - - -async def test_task_metadata_in_call_tool_request() -> None: - """Test that task metadata is accessible via ctx when calling a tool.""" - captured_task_metadata: TaskMetadata | None = None - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - nonlocal captured_task_metadata - captured_task_metadata = ctx.experimental.task_metadata - return CallToolResult(content=[TextContent(type="text", text="done")]) - - server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - async with Client(server) as client: - # Call tool with task metadata - await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="long_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CallToolResult, - ) - - assert captured_task_metadata is not None - assert captured_task_metadata.ttl == 60000 - - -async def test_task_metadata_is_task_property() -> None: - """Test that ctx.experimental.is_task works correctly.""" - is_task_values: list[bool] = [] - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - is_task_values.append(ctx.experimental.is_task) - return CallToolResult(content=[TextContent(type="text", text="done")]) - - server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - async with Client(server) as client: - # Call without task metadata - await client.session.send_request( - CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), - CallToolResult, - ) - - # Call with task metadata - await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), - ), - CallToolResult, - ) - - assert len(is_task_values) == 2 - assert is_task_values[0] is False # First call without task - assert is_task_values[1] is True # Second call with task - - -async def test_update_capabilities_no_handlers() -> None: - """Test that update_capabilities returns early when no task handlers are registered.""" - server = Server("test-no-handlers") - _ = server.experimental - - caps = server.get_capabilities(NotificationOptions(), {}) - assert caps.tasks is None - - -async def test_update_capabilities_partial_handlers() -> None: - """Test that update_capabilities skips list/cancel when only tasks/get is registered.""" - server = Server("test-partial") - # Access .experimental to create the ExperimentalHandlers instance - exp = server.experimental - # Second access returns the same cached instance - assert server.experimental is exp - - async def noop_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - server._add_request_handler("tasks/get", noop_get) - - caps = server.get_capabilities(NotificationOptions(), {}) - assert caps.tasks is not None - assert caps.tasks.list is None - assert caps.tasks.cancel is None - - -async def test_default_task_handlers_via_enable_tasks() -> None: - """Test that enable_tasks() auto-registers working default handlers.""" - server = Server("test-default-handlers") - task_support = server.experimental.enable_tasks() - store = task_support.store - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server() -> None: - async with task_support.run(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - task_support.configure_session(server_session) - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, {}, False) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task directly in the store for testing - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Test list_tasks (default handler) - list_result = await client_session.experimental.list_tasks() - assert len(list_result.tasks) == 1 - assert list_result.tasks[0].task_id == task.task_id - - # Test get_task (default handler - found) - get_result = await client_session.experimental.get_task(task.task_id) - assert get_result.task_id == task.task_id - assert get_result.status == "working" - - # Test get_task (default handler - not found path) - with pytest.raises(MCPError, match="not found"): - await client_session.experimental.get_task("nonexistent-task") - - # Create a completed task to test get_task_result - completed_task = await store.create_task(TaskMetadata(ttl=60000)) - await store.store_result( - completed_task.task_id, CallToolResult(content=[TextContent(type="text", text="Test result")]) - ) - await store.update_task(completed_task.task_id, status="completed") - - # Test get_task_result (default handler) - payload_result = await client_session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=completed_task.task_id)), - GetTaskPayloadResult, - ) - # The result should have the related-task metadata - assert payload_result.meta is not None - assert "io.modelcontextprotocol/related-task" in payload_result.meta - - # Test cancel_task (default handler) - cancel_result = await client_session.experimental.cancel_task(task.task_id) - assert cancel_result.task_id == task.task_id - assert cancel_result.status == "cancelled" - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_build_elicit_form_request() -> None: - """Test that _build_elicit_form_request builds a proper elicitation request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), - ) as server_session: - # Test without task_id - request = server_session._build_elicit_form_request( - message="Test message", - requested_schema={"type": "object", "properties": {"answer": {"type": "string"}}}, - ) - assert request.method == "elicitation/create" - assert request.params is not None - assert request.params["message"] == "Test message" - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_elicit_form_request( - message="Task message", - requested_schema={"type": "object"}, - related_task_id="test-task-123", - ) - assert request_with_task.method == "elicitation/create" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] == "test-task-123" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_build_elicit_url_request() -> None: - """Test that _build_elicit_url_request builds a proper URL mode elicitation request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), - ) as server_session: - # Test without related_task_id - request = server_session._build_elicit_url_request( - message="Please authorize with GitHub", - url="https://github.com/login/oauth/authorize", - elicitation_id="oauth-123", - ) - assert request.method == "elicitation/create" - assert request.params is not None - assert request.params["message"] == "Please authorize with GitHub" - assert request.params["url"] == "https://github.com/login/oauth/authorize" - assert request.params["elicitationId"] == "oauth-123" - assert request.params["mode"] == "url" - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_elicit_url_request( - message="OAuth required", - url="https://example.com/oauth", - elicitation_id="oauth-456", - related_task_id="test-task-789", - ) - assert request_with_task.method == "elicitation/create" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] == "test-task-789" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_build_create_message_request() -> None: - """Test that _build_create_message_request builds a proper sampling request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - messages = [ - SamplingMessage(role="user", content=TextContent(type="text", text="Hello")), - ] - - # Test without task_id - request = server_session._build_create_message_request( - messages=messages, - max_tokens=100, - system_prompt="You are helpful", - ) - assert request.method == "sampling/createMessage" - assert request.params is not None - assert request.params["maxTokens"] == 100 - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_create_message_request( - messages=messages, - max_tokens=50, - related_task_id="sampling-task-456", - ) - assert request_with_task.method == "sampling/createMessage" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] - == "sampling-task-456" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_send_message() -> None: - """Test that send_message sends a raw session message.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Create a test message - notification = JSONRPCNotification(jsonrpc="2.0", method="test/notification") - message = SessionMessage( - message=notification, - metadata=ServerMessageMetadata(related_request_id="test-req-1"), - ) - - # Send the message - await server_session.send_message(message) - - # Verify it was sent to the stream - received = await server_to_client_receive.receive() - assert isinstance(received.message, JSONRPCNotification) - assert received.message.method == "test/notification" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_success() -> None: - """Test that response routing works for success responses.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track routed responses with event for synchronization - routed_responses: list[dict[str, Any]] = [] - response_received = anyio.Event() - - class TestRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - routed_responses.append({"id": request_id, "response": response}) - response_received.set() - return True # Handled - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - router = TestRouter() - server_session.add_response_router(router) - - # Simulate receiving a response from client - response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=response) - - # Send from "client" side - await client_to_server_send.send(message) - - # Wait for response to be routed - with anyio.fail_after(5): - await response_received.wait() - - # Verify response was routed - assert len(routed_responses) == 1 - assert routed_responses[0]["id"] == "test-req-1" - assert routed_responses[0]["response"]["status"] == "ok" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_error() -> None: - """Test that error routing works for error responses.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track routed errors with event for synchronization - routed_errors: list[dict[str, Any]] = [] - error_received = anyio.Event() - - class TestRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - routed_errors.append({"id": request_id, "error": error}) - error_received.set() - return True # Handled - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - router = TestRouter() - server_session.add_response_router(router) - - # Simulate receiving an error response from client - error_data = ErrorData(code=INVALID_REQUEST, message="Test error") - error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=error_response) - - # Send from "client" side - await client_to_server_send.send(message) - - # Wait for error to be routed - with anyio.fail_after(5): - await error_received.wait() - - # Verify error was routed - assert len(routed_errors) == 1 - assert routed_errors[0]["id"] == "test-req-2" - assert routed_errors[0]["error"].message == "Test error" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_skips_non_matching_routers() -> None: - """Test that routing continues to next router when first doesn't match.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track which routers were called - router_calls: list[str] = [] - response_received = anyio.Event() - - class NonMatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - router_calls.append("non_matching_response") - return False # Doesn't handle it - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - class MatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - router_calls.append("matching_response") - response_received.set() - return True # Handles it - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Add non-matching router first, then matching router - server_session.add_response_router(NonMatchingRouter()) - server_session.add_response_router(MatchingRouter()) - - # Send a response - should skip first router and be handled by second - response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=response) - await client_to_server_send.send(message) - - with anyio.fail_after(5): - await response_received.wait() - - # Verify both routers were called (first returned False, second returned True) - assert router_calls == ["non_matching_response", "matching_response"] - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_error_routing_skips_non_matching_routers() -> None: - """Test that error routing continues to next router when first doesn't match.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track which routers were called - router_calls: list[str] = [] - error_received = anyio.Event() - - class NonMatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - router_calls.append("non_matching_error") - return False # Doesn't handle it - - class MatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - router_calls.append("matching_error") - error_received.set() - return True # Handles it - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Add non-matching router first, then matching router - server_session.add_response_router(NonMatchingRouter()) - server_session.add_response_router(MatchingRouter()) - - # Send an error - should skip first router and be handled by second - error_data = ErrorData(code=INVALID_REQUEST, message="Test error") - error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=error_response) - await client_to_server_send.send(message) - - with anyio.fail_after(5): - await error_received.wait() - - # Verify both routers were called (first returned False, second returned True) - assert router_calls == ["non_matching_error", "matching_error"] - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() diff --git a/tests/experimental/tasks/server/test_server_task_context.py b/tests/experimental/tasks/server/test_server_task_context.py deleted file mode 100644 index e23299698c..0000000000 --- a/tests/experimental/tasks/server/test_server_task_context.py +++ /dev/null @@ -1,709 +0,0 @@ -"""Tests for ServerTaskContext.""" - -import asyncio -from unittest.mock import AsyncMock, Mock - -import anyio -import pytest - -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.types import ( - CallToolResult, - ClientCapabilities, - ClientTasksCapability, - ClientTasksRequestsCapability, - Implementation, - InitializeRequestParams, - JSONRPCRequest, - SamplingMessage, - TaskMetadata, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksSamplingCapability, - TextContent, -) - - -@pytest.mark.anyio -async def test_server_task_context_properties() -> None: - """Test ServerTaskContext property accessors.""" - store = InMemoryTaskStore() - mock_session = Mock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-123") - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - assert ctx.task_id == "test-123" - assert ctx.task.task_id == "test-123" - assert ctx.is_cancelled is False - - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_request_cancellation() -> None: - """Test ServerTaskContext.request_cancellation().""" - store = InMemoryTaskStore() - mock_session = Mock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - assert ctx.is_cancelled is False - ctx.request_cancellation() - assert ctx.is_cancelled is True - - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_update_status_with_notify() -> None: - """Test update_status sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - await ctx.update_status("Working...", notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_update_status_without_notify() -> None: - """Test update_status skips notification when notify=False.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - await ctx.update_status("Working...", notify=False) - - mock_session.send_notification.assert_not_called() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_complete_with_notify() -> None: - """Test complete sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - result = CallToolResult(content=[TextContent(type="text", text="Done")]) - await ctx.complete(result, notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_fail_with_notify() -> None: - """Test fail sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - await ctx.fail("Something went wrong", notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_raises_when_client_lacks_capability() -> None: - """Test that elicit() raises MCPError when client doesn't support elicitation.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=False) - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - with pytest.raises(MCPError) as exc_info: - await ctx.elicit(message="Test?", requested_schema={"type": "object"}) - - assert "elicitation capability" in exc_info.value.error.message - mock_session.check_client_capability.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_raises_when_client_lacks_capability() -> None: - """Test that create_message() raises MCPError when client doesn't support sampling.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=False) - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - with pytest.raises(MCPError) as exc_info: - await ctx.create_message(messages=[], max_tokens=100) - - assert "sampling capability" in exc_info.value.error.message - mock_session.check_client_capability.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_raises_without_handler() -> None: - """Test that elicit() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required"): - await ctx.elicit(message="Test?", requested_schema={"type": "object"}) - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_url_raises_without_handler() -> None: - """Test that elicit_url() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for elicit_url"): - await ctx.elicit_url( - message="Please authorize", - url="https://example.com/oauth", - elicitation_id="oauth-123", - ) - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_raises_without_handler() -> None: - """Test that create_message() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required"): - await ctx.create_message(messages=[], max_tokens=100) - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_queues_request_and_waits_for_response() -> None: - """Test that elicit() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_form_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-1", - method="elicitation/create", - params={"message": "Test?", "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - elicit_result = None - - async def run_elicit() -> None: - nonlocal elicit_result - elicit_result = await ctx.elicit( - message="Test?", - requested_schema={"type": "object"}, - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_elicit) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock elicitation response - msg.resolver.set_result({"action": "accept", "content": {"name": "Alice"}}) - - # Verify result - assert elicit_result is not None - assert elicit_result.action == "accept" - assert elicit_result.content == {"name": "Alice"} - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_url_queues_request_and_waits_for_response() -> None: - """Test that elicit_url() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_url_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-url-req-1", - method="elicitation/create", - params={"message": "Authorize", "url": "https://example.com", "elicitationId": "123", "mode": "url"}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - elicit_result = None - - async def run_elicit_url() -> None: - nonlocal elicit_result - elicit_result = await ctx.elicit_url( - message="Authorize", - url="https://example.com/oauth", - elicitation_id="oauth-123", - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_elicit_url) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock elicitation response (URL mode just returns action) - msg.resolver.set_result({"action": "accept"}) - - # Verify result - assert elicit_result is not None - assert elicit_result.action == "accept" - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_queues_request_and_waits_for_response() -> None: - """Test that create_message() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_create_message_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-2", - method="sampling/createMessage", - params={"messages": [], "maxTokens": 100, "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - sampling_result = None - - async def run_sampling() -> None: - nonlocal sampling_result - sampling_result = await ctx.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_sampling) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock sampling response - msg.resolver.set_result( - { - "role": "assistant", - "content": {"type": "text", "text": "Hello back!"}, - "model": "test-model", - "stopReason": "endTurn", - } - ) - - # Verify result - assert sampling_result is not None - assert sampling_result.role == "assistant" - assert sampling_result.model == "test-model" - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_restores_status_on_cancellation() -> None: - """Test that elicit() restores task status to working when cancelled.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_form_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-cancel", - method="elicitation/create", - params={"message": "Test?", "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - cancelled_error_raised = False - - async with anyio.create_task_group() as tg: - - async def do_elicit() -> None: - nonlocal cancelled_error_raised - try: - await ctx.elicit( - message="Test?", - requested_schema={"type": "object"}, - ) - except anyio.get_cancelled_exc_class(): - cancelled_error_raised = True - # Don't re-raise - let the test continue - - tg.start_soon(do_elicit) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Get the queued message and set cancellation exception on its resolver - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Trigger cancellation by setting exception (use asyncio.CancelledError directly) - msg.resolver.set_exception(asyncio.CancelledError()) - - # Verify task is back to working after cancellation - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - assert cancelled_error_raised - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_restores_status_on_cancellation() -> None: - """Test that create_message() restores task status to working when cancelled.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_create_message_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-cancel-2", - method="sampling/createMessage", - params={"messages": [], "maxTokens": 100, "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - cancelled_error_raised = False - - async with anyio.create_task_group() as tg: - - async def do_sampling() -> None: - nonlocal cancelled_error_raised - try: - await ctx.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - except anyio.get_cancelled_exc_class(): - cancelled_error_raised = True - # Don't re-raise - - tg.start_soon(do_sampling) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Get the queued message and set cancellation exception on its resolver - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Trigger cancellation by setting exception (use asyncio.CancelledError directly) - msg.resolver.set_exception(asyncio.CancelledError()) - - # Verify task is back to working after cancellation - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - assert cancelled_error_raised - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_as_task_raises_without_handler() -> None: - """Test that elicit_as_task() raises when handler is not provided.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Create mock session with proper client capabilities - mock_session = Mock() - mock_session.client_params = InitializeRequestParams( - protocol_version="2025-01-01", - capabilities=ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ), - client_info=Implementation(name="test", version="1.0"), - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for elicit_as_task"): - await ctx.elicit_as_task(message="Test?", requested_schema={"type": "object"}) - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_as_task_raises_without_handler() -> None: - """Test that create_message_as_task() raises when handler is not provided.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Create mock session with proper client capabilities - mock_session = Mock() - mock_session.client_params = InitializeRequestParams( - protocol_version="2025-01-01", - capabilities=ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ), - client_info=Implementation(name="test", version="1.0"), - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for create_message_as_task"): - await ctx.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - - store.cleanup() diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py deleted file mode 100644 index 0d431899c8..0000000000 --- a/tests/experimental/tasks/server/test_store.py +++ /dev/null @@ -1,406 +0,0 @@ -"""Tests for InMemoryTaskStore.""" - -from collections.abc import AsyncIterator -from datetime import datetime, timedelta, timezone - -import pytest - -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import cancel_task -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import INVALID_PARAMS, CallToolResult, TaskMetadata, TextContent - - -@pytest.fixture -async def store() -> AsyncIterator[InMemoryTaskStore]: - """Provide a clean InMemoryTaskStore for each test with automatic cleanup.""" - store = InMemoryTaskStore() - yield store - store.cleanup() - - -@pytest.mark.anyio -async def test_create_and_get(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore create and get operations.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - assert task.task_id is not None - assert task.status == "working" - assert task.ttl == 60000 - - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.task_id == task.task_id - assert retrieved.status == "working" - - -@pytest.mark.anyio -async def test_create_with_custom_id(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore create with custom task ID.""" - task = await store.create_task( - metadata=TaskMetadata(ttl=60000), - task_id="my-custom-id", - ) - - assert task.task_id == "my-custom-id" - assert task.status == "working" - - retrieved = await store.get_task("my-custom-id") - assert retrieved is not None - assert retrieved.task_id == "my-custom-id" - - -@pytest.mark.anyio -async def test_create_duplicate_id_raises(store: InMemoryTaskStore) -> None: - """Test that creating a task with duplicate ID raises.""" - await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") - - with pytest.raises(ValueError, match="already exists"): - await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") - - -@pytest.mark.anyio -async def test_get_nonexistent_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting a nonexistent task returns None.""" - retrieved = await store.get_task("nonexistent") - assert retrieved is None - - -@pytest.mark.anyio -async def test_update_status(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore status updates.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - updated = await store.update_task(task.task_id, status="completed", status_message="All done!") - - assert updated.status == "completed" - assert updated.status_message == "All done!" - - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.status == "completed" - assert retrieved.status_message == "All done!" - - -@pytest.mark.anyio -async def test_update_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that updating a nonexistent task raises.""" - with pytest.raises(ValueError, match="not found"): - await store.update_task("nonexistent", status="completed") - - -@pytest.mark.anyio -async def test_store_and_get_result(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore result storage and retrieval.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # Store result - result = CallToolResult(content=[TextContent(type="text", text="Result data")]) - await store.store_result(task.task_id, result) - - # Retrieve result - retrieved_result = await store.get_result(task.task_id) - assert retrieved_result == result - - -@pytest.mark.anyio -async def test_get_result_nonexistent_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting result for nonexistent task returns None.""" - result = await store.get_result("nonexistent") - assert result is None - - -@pytest.mark.anyio -async def test_get_result_no_result_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting result when none stored returns None.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - result = await store.get_result(task.task_id) - assert result is None - - -@pytest.mark.anyio -async def test_list_tasks(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore list operation.""" - # Create multiple tasks - for _ in range(3): - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - tasks, next_cursor = await store.list_tasks() - assert len(tasks) == 3 - assert next_cursor is None # Less than page size - - -@pytest.mark.anyio -async def test_list_tasks_pagination() -> None: - """Test InMemoryTaskStore pagination.""" - # Needs custom page_size, can't use fixture - store = InMemoryTaskStore(page_size=2) - - # Create 5 tasks - for _ in range(5): - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # First page - tasks, next_cursor = await store.list_tasks() - assert len(tasks) == 2 - assert next_cursor is not None - - # Second page - tasks, next_cursor = await store.list_tasks(cursor=next_cursor) - assert len(tasks) == 2 - assert next_cursor is not None - - # Third page (last) - tasks, next_cursor = await store.list_tasks(cursor=next_cursor) - assert len(tasks) == 1 - assert next_cursor is None - - store.cleanup() - - -@pytest.mark.anyio -async def test_list_tasks_invalid_cursor(store: InMemoryTaskStore) -> None: - """Test that invalid cursor raises.""" - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - with pytest.raises(ValueError, match="Invalid cursor"): - await store.list_tasks(cursor="invalid-cursor") - - -@pytest.mark.anyio -async def test_delete_task(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore delete operation.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - deleted = await store.delete_task(task.task_id) - assert deleted is True - - retrieved = await store.get_task(task.task_id) - assert retrieved is None - - # Delete non-existent - deleted = await store.delete_task(task.task_id) - assert deleted is False - - -@pytest.mark.anyio -async def test_get_all_tasks_helper(store: InMemoryTaskStore) -> None: - """Test the get_all_tasks debugging helper.""" - await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - all_tasks = store.get_all_tasks() - assert len(all_tasks) == 2 - - -@pytest.mark.anyio -async def test_store_result_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that storing result for nonexistent task raises ValueError.""" - result = CallToolResult(content=[TextContent(type="text", text="Result")]) - - with pytest.raises(ValueError, match="not found"): - await store.store_result("nonexistent-id", result) - - -@pytest.mark.anyio -async def test_create_task_with_null_ttl(store: InMemoryTaskStore) -> None: - """Test creating task with null TTL (never expires).""" - task = await store.create_task(metadata=TaskMetadata(ttl=None)) - - assert task.ttl is None - - # Task should persist (not expire) - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - - -@pytest.mark.anyio -async def test_task_expiration_cleanup(store: InMemoryTaskStore) -> None: - """Test that expired tasks are cleaned up lazily.""" - # Create a task with very short TTL - task = await store.create_task(metadata=TaskMetadata(ttl=1)) # 1ms TTL - - # Manually force the expiry to be in the past - stored = store._tasks.get(task.task_id) - assert stored is not None - stored.expires_at = datetime.now(timezone.utc) - timedelta(seconds=10) - - # Task should still exist in internal dict but be expired - assert task.task_id in store._tasks - - # Any access operation should clean up expired tasks - # list_tasks triggers cleanup - tasks, _ = await store.list_tasks() - - # Expired task should be cleaned up - assert task.task_id not in store._tasks - assert len(tasks) == 0 - - -@pytest.mark.anyio -async def test_task_with_null_ttl_never_expires(store: InMemoryTaskStore) -> None: - """Test that tasks with null TTL never expire during cleanup.""" - # Create task with null TTL - task = await store.create_task(metadata=TaskMetadata(ttl=None)) - - # Verify internal storage has no expiry - stored = store._tasks.get(task.task_id) - assert stored is not None - assert stored.expires_at is None - - # Access operations should NOT remove this task - await store.list_tasks() - await store.get_task(task.task_id) - - # Task should still exist - assert task.task_id in store._tasks - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - - -@pytest.mark.anyio -async def test_terminal_task_ttl_reset(store: InMemoryTaskStore) -> None: - """Test that TTL is reset when task enters terminal state.""" - # Create task with short TTL - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) # 60s - - # Get the initial expiry - stored = store._tasks.get(task.task_id) - assert stored is not None - initial_expiry = stored.expires_at - assert initial_expiry is not None - - # Update to terminal state (completed) - await store.update_task(task.task_id, status="completed") - - # Expiry should be reset to a new time (from now + TTL) - new_expiry = stored.expires_at - assert new_expiry is not None - assert new_expiry >= initial_expiry - - -@pytest.mark.anyio -async def test_terminal_status_transition_rejected(store: InMemoryTaskStore) -> None: - """Test that transitions from terminal states are rejected. - - Per spec: Terminal states (completed, failed, cancelled) MUST NOT - transition to any other status. - """ - # Test each terminal status - for terminal_status in ("completed", "failed", "cancelled"): - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # Move to terminal state - await store.update_task(task.task_id, status=terminal_status) - - # Attempting to transition to any other status should raise - with pytest.raises(ValueError, match="Cannot transition from terminal status"): - await store.update_task(task.task_id, status="working") - - # Also test transitioning to another terminal state - other_terminal = "failed" if terminal_status != "failed" else "completed" - with pytest.raises(ValueError, match="Cannot transition from terminal status"): - await store.update_task(task.task_id, status=other_terminal) - - -@pytest.mark.anyio -async def test_terminal_status_allows_same_status(store: InMemoryTaskStore) -> None: - """Test that setting the same terminal status doesn't raise. - - This is not a transition, so it should be allowed (no-op). - """ - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="completed") - - # Setting the same status should not raise - updated = await store.update_task(task.task_id, status="completed") - assert updated.status == "completed" - - # Updating just the message should also work - updated = await store.update_task(task.task_id, status_message="Updated message") - assert updated.status_message == "Updated message" - - -@pytest.mark.anyio -async def test_wait_for_update_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that wait_for_update raises for nonexistent task.""" - with pytest.raises(ValueError, match="not found"): - await store.wait_for_update("nonexistent-task-id") - - -@pytest.mark.anyio -async def test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) -> None: - """Test cancel_task helper succeeds for a working task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - assert task.status == "working" - - result = await cancel_task(store, task.task_id) - - assert result.task_id == task.task_id - assert result.status == "cancelled" - - # Verify store is updated - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.status == "cancelled" - - -@pytest.mark.anyio -async def test_cancel_task_rejects_nonexistent_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for nonexistent task.""" - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, "nonexistent-task-id") - - assert exc_info.value.error.code == INVALID_PARAMS - assert "not found" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_completed_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for completed task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="completed") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'completed'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_failed_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for failed task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="failed") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'failed'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_already_cancelled_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for already cancelled task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="cancelled") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'cancelled'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_succeeds_for_input_required_task(store: InMemoryTaskStore) -> None: - """Test cancel_task helper succeeds for a task in input_required status.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="input_required") - - result = await cancel_task(store, task.task_id) - - assert result.task_id == task.task_id - assert result.status == "cancelled" diff --git a/tests/experimental/tasks/server/test_task_result_handler.py b/tests/experimental/tasks/server/test_task_result_handler.py deleted file mode 100644 index 8b5a03ce2b..0000000000 --- a/tests/experimental/tasks/server/test_task_result_handler.py +++ /dev/null @@ -1,354 +0,0 @@ -"""Tests for TaskResultHandler.""" - -from collections.abc import AsyncIterator -from typing import Any -from unittest.mock import AsyncMock, Mock - -import anyio -import pytest - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.message import SessionMessage -from mcp.types import ( - INVALID_REQUEST, - CallToolResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - JSONRPCRequest, - TaskMetadata, - TextContent, -) - - -@pytest.fixture -async def store() -> AsyncIterator[InMemoryTaskStore]: - """Provide a clean store for each test.""" - s = InMemoryTaskStore() - yield s - s.cleanup() - - -@pytest.fixture -def queue() -> InMemoryTaskMessageQueue: - """Provide a clean queue for each test.""" - return InMemoryTaskMessageQueue() - - -@pytest.fixture -def handler(store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue) -> TaskResultHandler: - """Provide a handler for each test.""" - return TaskResultHandler(store, queue) - - -@pytest.mark.anyio -async def test_handle_returns_result_for_completed_task( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() returns the stored result for a completed task.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - result = CallToolResult(content=[TextContent(type="text", text="Done!")]) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - response = await handler.handle(request, mock_session, "req-1") - - assert response is not None - assert response.meta is not None - assert "io.modelcontextprotocol/related-task" in response.meta - - -@pytest.mark.anyio -async def test_handle_raises_for_nonexistent_task( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() raises MCPError for nonexistent task.""" - mock_session = Mock() - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="nonexistent")) - - with pytest.raises(MCPError) as exc_info: - await handler.handle(request, mock_session, "req-1") - - assert "not found" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_handle_returns_empty_result_when_no_result_stored( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() returns minimal result when task completed without stored result.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - await store.update_task(task.task_id, status="completed") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - response = await handler.handle(request, mock_session, "req-1") - - assert response is not None - assert response.meta is not None - assert "io.modelcontextprotocol/related-task" in response.meta - - -@pytest.mark.anyio -async def test_handle_delivers_queued_messages( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() delivers queued messages before returning.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - queued_msg = QueuedMessage( - type="notification", - message=JSONRPCRequest( - jsonrpc="2.0", - id="notif-1", - method="test/notification", - params={}, - ), - ) - await queue.enqueue(task.task_id, queued_msg) - await store.update_task(task.task_id, status="completed") - - sent_messages: list[SessionMessage] = [] - - async def track_send(msg: SessionMessage) -> None: - sent_messages.append(msg) - - mock_session = Mock() - mock_session.send_message = track_send - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - await handler.handle(request, mock_session, "req-1") - - assert len(sent_messages) == 1 - - -@pytest.mark.anyio -async def test_handle_waits_for_task_completion( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() waits for task to complete before returning.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - result_holder: list[GetTaskPayloadResult | None] = [None] - - async def run_handle() -> None: - result_holder[0] = await handler.handle(request, mock_session, "req-1") - - async with anyio.create_task_group() as tg: - tg.start_soon(run_handle) - - # Wait for handler to start waiting (event gets created when wait starts) - while task.task_id not in store._update_events: - await anyio.sleep(0) - - await store.store_result(task.task_id, CallToolResult(content=[TextContent(type="text", text="Done")])) - await store.update_task(task.task_id, status="completed") - - assert result_holder[0] is not None - - -@pytest.mark.anyio -async def test_route_response_resolves_pending_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() resolves a pending request.""" - resolver: Resolver[dict[str, Any]] = Resolver() - handler._pending_requests["req-123"] = resolver - - result = handler.route_response("req-123", {"status": "ok"}) - - assert result is True - assert resolver.done() - assert await resolver.wait() == {"status": "ok"} - - -@pytest.mark.anyio -async def test_route_response_returns_false_for_unknown_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() returns False for unknown request ID.""" - result = handler.route_response("unknown-req", {"status": "ok"}) - assert result is False - - -@pytest.mark.anyio -async def test_route_response_returns_false_for_already_done_resolver( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() returns False if resolver already completed.""" - resolver: Resolver[dict[str, Any]] = Resolver() - resolver.set_result({"already": "done"}) - handler._pending_requests["req-123"] = resolver - - result = handler.route_response("req-123", {"new": "data"}) - - assert result is False - - -@pytest.mark.anyio -async def test_route_error_resolves_pending_request_with_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_error() sets exception on pending request.""" - resolver: Resolver[dict[str, Any]] = Resolver() - handler._pending_requests["req-123"] = resolver - - error = ErrorData(code=INVALID_REQUEST, message="Something went wrong") - result = handler.route_error("req-123", error) - - assert result is True - assert resolver.done() - - with pytest.raises(MCPError) as exc_info: - await resolver.wait() - assert exc_info.value.error.message == "Something went wrong" - - -@pytest.mark.anyio -async def test_route_error_returns_false_for_unknown_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_error() returns False for unknown request ID.""" - error = ErrorData(code=INVALID_REQUEST, message="Error") - result = handler.route_error("unknown-req", error) - assert result is False - - -@pytest.mark.anyio -async def test_deliver_registers_resolver_for_request_messages( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _deliver_queued_messages registers resolvers for request messages.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - resolver: Resolver[dict[str, Any]] = Resolver() - queued_msg = QueuedMessage( - type="request", - message=JSONRPCRequest( - jsonrpc="2.0", - id="inner-req-1", - method="elicitation/create", - params={}, - ), - resolver=resolver, - original_request_id="inner-req-1", - ) - await queue.enqueue(task.task_id, queued_msg) - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - await handler._deliver_queued_messages(task.task_id, mock_session, "outer-req-1") - - assert "inner-req-1" in handler._pending_requests - assert handler._pending_requests["inner-req-1"] is resolver - - -@pytest.mark.anyio -async def test_deliver_skips_resolver_registration_when_no_original_id( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _deliver_queued_messages skips resolver registration when original_request_id is None.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - resolver: Resolver[dict[str, Any]] = Resolver() - queued_msg = QueuedMessage( - type="request", - message=JSONRPCRequest( - jsonrpc="2.0", - id="inner-req-1", - method="elicitation/create", - params={}, - ), - resolver=resolver, - original_request_id=None, # No original request ID - ) - await queue.enqueue(task.task_id, queued_msg) - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - await handler._deliver_queued_messages(task.task_id, mock_session, "outer-req-1") - - # Resolver should NOT be registered since original_request_id is None - assert len(handler._pending_requests) == 0 - # But the message should still be sent - mock_session.send_message.assert_called_once() - - -@pytest.mark.anyio -async def test_wait_for_task_update_handles_store_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _wait_for_task_update handles store exception gracefully.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - # Make wait_for_update raise an exception - async def failing_wait(task_id: str) -> None: - raise RuntimeError("Store error") - - store.wait_for_update = failing_wait # type: ignore[method-assign] - - # Queue a message to unblock the race via the queue path - async def enqueue_later() -> None: - # Wait for queue to start waiting (event gets created when wait starts) - while task.task_id not in queue._events: - await anyio.sleep(0) - await queue.enqueue( - task.task_id, - QueuedMessage( - type="notification", - message=JSONRPCRequest( - jsonrpc="2.0", - id="notif-1", - method="test/notification", - params={}, - ), - ), - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(enqueue_later) - # This should complete via the queue path even though store raises - await handler._wait_for_task_update(task.task_id) - - -@pytest.mark.anyio -async def test_wait_for_task_update_handles_queue_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _wait_for_task_update handles queue exception gracefully.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - # Make wait_for_message raise an exception - async def failing_wait(task_id: str) -> None: - raise RuntimeError("Queue error") - - queue.wait_for_message = failing_wait # type: ignore[method-assign] - - # Update the store to unblock the race via the store path - async def update_later() -> None: - # Wait for store to start waiting (event gets created when wait starts) - while task.task_id not in store._update_events: - await anyio.sleep(0) - await store.update_task(task.task_id, status="completed") - - async with anyio.create_task_group() as tg: - tg.start_soon(update_later) - # This should complete via the store path even though queue raises - await handler._wait_for_task_update(task.task_id) diff --git a/tests/experimental/tasks/test_capabilities.py b/tests/experimental/tasks/test_capabilities.py deleted file mode 100644 index 90a8656ba0..0000000000 --- a/tests/experimental/tasks/test_capabilities.py +++ /dev/null @@ -1,283 +0,0 @@ -"""Tests for tasks capability checking utilities.""" - -import pytest - -from mcp import MCPError -from mcp.shared.experimental.tasks.capabilities import ( - check_tasks_capability, - has_task_augmented_elicitation, - has_task_augmented_sampling, - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.types import ( - ClientCapabilities, - ClientTasksCapability, - ClientTasksRequestsCapability, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksSamplingCapability, -) - - -class TestCheckTasksCapability: - """Tests for check_tasks_capability function.""" - - def test_required_requests_none_returns_true(self) -> None: - """When required.requests is None, should return True.""" - required = ClientTasksCapability() - client = ClientTasksCapability() - assert check_tasks_capability(required, client) is True - - def test_client_requests_none_returns_false(self) -> None: - """When client.requests is None but required.requests is set, should return False.""" - required = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - client = ClientTasksCapability() - assert check_tasks_capability(required, client) is False - - def test_elicitation_required_but_client_missing(self) -> None: - """When elicitation is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability()) - ) - client = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - assert check_tasks_capability(required, client) is False - - def test_elicitation_create_required_but_client_missing(self) -> None: - """When elicitation.create is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability() # No create - ) - ) - assert check_tasks_capability(required, client) is False - - def test_elicitation_create_present(self) -> None: - """When elicitation.create is required and client has it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_sampling_required_but_client_missing(self) -> None: - """When sampling is required but client doesn't have it.""" - required = ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability())) - client = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - assert check_tasks_capability(required, client) is False - - def test_sampling_create_message_required_but_client_missing(self) -> None: - """When sampling.createMessage is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability() # No createMessage - ) - ) - assert check_tasks_capability(required, client) is False - - def test_sampling_create_message_present(self) -> None: - """When sampling.createMessage is required and client has it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_both_elicitation_and_sampling_present(self) -> None: - """When both elicitation.create and sampling.createMessage are required and client has both.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()), - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()), - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()), - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()), - ) - ) - assert check_tasks_capability(required, client) is True - - def test_elicitation_without_create_required(self) -> None: - """When elicitation is required but not create specifically.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability() # No create - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_sampling_without_create_message_required(self) -> None: - """When sampling is required but not createMessage specifically.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability() # No createMessage - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - -class TestHasTaskAugmentedElicitation: - """Tests for has_task_augmented_elicitation function.""" - - def test_tasks_none(self) -> None: - """Returns False when caps.tasks is None.""" - caps = ClientCapabilities() - assert has_task_augmented_elicitation(caps) is False - - def test_requests_none(self) -> None: - """Returns False when caps.tasks.requests is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability()) - assert has_task_augmented_elicitation(caps) is False - - def test_elicitation_none(self) -> None: - """Returns False when caps.tasks.requests.elicitation is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability())) - assert has_task_augmented_elicitation(caps) is False - - def test_create_none(self) -> None: - """Returns False when caps.tasks.requests.elicitation.create is None.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability()) - ) - ) - assert has_task_augmented_elicitation(caps) is False - - def test_create_present(self) -> None: - """Returns True when full capability path is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ) - assert has_task_augmented_elicitation(caps) is True - - -class TestHasTaskAugmentedSampling: - """Tests for has_task_augmented_sampling function.""" - - def test_tasks_none(self) -> None: - """Returns False when caps.tasks is None.""" - caps = ClientCapabilities() - assert has_task_augmented_sampling(caps) is False - - def test_requests_none(self) -> None: - """Returns False when caps.tasks.requests is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability()) - assert has_task_augmented_sampling(caps) is False - - def test_sampling_none(self) -> None: - """Returns False when caps.tasks.requests.sampling is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability())) - assert has_task_augmented_sampling(caps) is False - - def test_create_message_none(self) -> None: - """Returns False when caps.tasks.requests.sampling.createMessage is None.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability())) - ) - assert has_task_augmented_sampling(caps) is False - - def test_create_message_present(self) -> None: - """Returns True when full capability path is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ) - assert has_task_augmented_sampling(caps) is True - - -class TestRequireTaskAugmentedElicitation: - """Tests for require_task_augmented_elicitation function.""" - - def test_raises_when_none(self) -> None: - """Raises MCPError when client_caps is None.""" - with pytest.raises(MCPError) as exc_info: - require_task_augmented_elicitation(None) - assert "task-augmented elicitation" in str(exc_info.value) - - def test_raises_when_missing(self) -> None: - """Raises MCPError when capability is missing.""" - caps = ClientCapabilities() - with pytest.raises(MCPError) as exc_info: - require_task_augmented_elicitation(caps) - assert "task-augmented elicitation" in str(exc_info.value) - - def test_passes_when_present(self) -> None: - """Does not raise when capability is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ) - require_task_augmented_elicitation(caps) - - -class TestRequireTaskAugmentedSampling: - """Tests for require_task_augmented_sampling function.""" - - def test_raises_when_none(self) -> None: - """Raises MCPError when client_caps is None.""" - with pytest.raises(MCPError) as exc_info: - require_task_augmented_sampling(None) - assert "task-augmented sampling" in str(exc_info.value) - - def test_raises_when_missing(self) -> None: - """Raises MCPError when capability is missing.""" - caps = ClientCapabilities() - with pytest.raises(MCPError) as exc_info: - require_task_augmented_sampling(caps) - assert "task-augmented sampling" in str(exc_info.value) - - def test_passes_when_present(self) -> None: - """Does not raise when capability is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ) - require_task_augmented_sampling(caps) diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py deleted file mode 100644 index 2d0378a9ce..0000000000 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ /dev/null @@ -1,695 +0,0 @@ -"""Tests for the four elicitation scenarios with tasks. - -This tests all combinations of tool call types and elicitation types: -1. Normal tool call + Normal elicitation (session.elicit) -2. Normal tool call + Task-augmented elicitation (session.experimental.elicit_as_task) -3. Task-augmented tool call + Normal elicitation (task.elicit) -4. Task-augmented tool call + Task-augmented elicitation (task.elicit_as_task) - -And the same for sampling (create_message). -""" - -from typing import Any - -import anyio -import pytest -from anyio import Event - -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.lowlevel import NotificationOptions -from mcp.shared._context import RequestContext -from mcp.shared.experimental.tasks.helpers import is_terminal -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.types import ( - TASK_REQUIRED, - CallToolRequestParams, - CallToolResult, - CreateMessageRequestParams, - CreateMessageResult, - CreateTaskResult, - ElicitRequestParams, - ElicitResult, - ErrorData, - GetTaskPayloadResult, - GetTaskResult, - ListToolsResult, - PaginatedRequestParams, - SamplingMessage, - TaskMetadata, - TextContent, - Tool, -) - - -def create_client_task_handlers( - client_task_store: InMemoryTaskStore, - elicit_received: Event, -) -> ExperimentalTaskHandlers: - """Create task handlers for client to handle task-augmented elicitation from server.""" - - elicit_response = ElicitResult(action="accept", content={"confirm": True}) - task_complete_events: dict[str, Event] = {} - - async def handle_augmented_elicitation( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - """Handle task-augmented elicitation by creating a client-side task.""" - elicit_received.set() - task = await client_task_store.create_task(task_metadata) - task_complete_events[task.task_id] = Event() - - async def complete_task() -> None: - # Store result before updating status to avoid race condition - await client_task_store.store_result(task.task_id, elicit_response) - await client_task_store.update_task(task.task_id, status="completed") - task_complete_events[task.task_id].set() - - context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] - return CreateTaskResult(task=task) - - async def handle_get_task( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskResult: - """Handle tasks/get from server.""" - task = await client_task_store.get_task(params.task_id) - assert task is not None, f"Task not found: {params.task_id}" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=100, - ) - - async def handle_get_task_result( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskPayloadResult | ErrorData: - """Handle tasks/result from server.""" - event = task_complete_events.get(params.task_id) - assert event is not None, f"No completion event for task: {params.task_id}" - await event.wait() - result = await client_task_store.get_result(params.task_id) - assert result is not None, f"Result not found for task: {params.task_id}" - return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) - - return ExperimentalTaskHandlers( - augmented_elicitation=handle_augmented_elicitation, - get_task=handle_get_task, - get_task_result=handle_get_task_result, - ) - - -def create_sampling_task_handlers( - client_task_store: InMemoryTaskStore, - sampling_received: Event, -) -> ExperimentalTaskHandlers: - """Create task handlers for client to handle task-augmented sampling from server.""" - - sampling_response = CreateMessageResult( - role="assistant", - content=TextContent(type="text", text="Hello from the model!"), - model="test-model", - ) - task_complete_events: dict[str, Event] = {} - - async def handle_augmented_sampling( - context: RequestContext[ClientSession], - params: CreateMessageRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - """Handle task-augmented sampling by creating a client-side task.""" - sampling_received.set() - task = await client_task_store.create_task(task_metadata) - task_complete_events[task.task_id] = Event() - - async def complete_task() -> None: - # Store result before updating status to avoid race condition - await client_task_store.store_result(task.task_id, sampling_response) - await client_task_store.update_task(task.task_id, status="completed") - task_complete_events[task.task_id].set() - - context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] - return CreateTaskResult(task=task) - - async def handle_get_task( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskResult: - """Handle tasks/get from server.""" - task = await client_task_store.get_task(params.task_id) - assert task is not None, f"Task not found: {params.task_id}" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=100, - ) - - async def handle_get_task_result( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskPayloadResult | ErrorData: - """Handle tasks/result from server.""" - event = task_complete_events.get(params.task_id) - assert event is not None, f"No completion event for task: {params.task_id}" - await event.wait() - result = await client_task_store.get_result(params.task_id) - assert result is not None, f"Result not found for task: {params.task_id}" - return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) - - return ExperimentalTaskHandlers( - augmented_sampling=handle_augmented_sampling, - get_task=handle_get_task, - get_task_result=handle_get_task_result, - ) - - -@pytest.mark.anyio -async def test_scenario1_normal_tool_normal_elicitation() -> None: - """Scenario 1: Normal tool call with normal elicitation. - - Server calls session.elicit() directly, client responds immediately. - """ - elicit_received = Event() - tool_result: list[str] = [] - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Normal elicitation - expects immediate response - result = await ctx.session.elicit( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - tool_result.append("confirmed" if confirmed else "cancelled") - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - server = Server("test-scenario1", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - # Elicitation callback for client - async def elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - ) -> ElicitResult: - elicit_received.set() - return ElicitResult(action="accept", content={"confirm": True}) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - elicitation_callback=elicitation_callback, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as task) - result = await client_session.call_tool("confirm_action", {}) - - # Verify elicitation was received and tool completed - assert elicit_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "confirmed" - - -@pytest.mark.anyio -async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: - """Scenario 2: Normal tool call with task-augmented elicitation. - - Server calls session.experimental.elicit_as_task(), client creates a task - for the elicitation and returns CreateTaskResult. Server polls client. - """ - elicit_received = Event() - tool_result: list[str] = [] - - # Client-side task store for handling task-augmented elicitation - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Task-augmented elicitation - server polls client - result = await ctx.session.experimental.elicit_as_task( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ttl=60000, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - tool_result.append("confirmed" if confirmed else "cancelled") - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - server = Server("test-scenario2", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - task_handlers = create_client_task_handlers(client_task_store, elicit_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as task) - result = await client_session.call_tool("confirm_action", {}) - - # Verify elicitation was received and tool completed - assert elicit_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "confirmed" - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: - """Scenario 3: Task-augmented tool call with normal elicitation. - - Client calls tool as task. Inside the task, server uses task.elicit() - which queues the request and delivers via tasks/result. - """ - elicit_received = Event() - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Normal elicitation within task - queued and delivered via tasks/result - result = await task.elicit( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario3", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - - # Elicitation callback for client - async def elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - ) -> ElicitResult: - elicit_received.set() - return ElicitResult(action="accept", content={"confirm": True}) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - elicitation_callback=elicitation_callback, - ) as client_session: - await client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("confirm_action", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required, then call tasks/result - found_input_required = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required": # pragma: no branch - found_input_required = True - break - assert found_input_required, "Expected to see input_required status" - - # This will deliver the elicitation and get the response - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert elicit_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - - -@pytest.mark.anyio -async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> None: - """Scenario 4: Task-augmented tool call with task-augmented elicitation. - - Client calls tool as task. Inside the task, server uses task.elicit_as_task() - which sends task-augmented elicitation. Client creates its own task for the - elicitation, and server polls the client. - - This tests the full bidirectional flow where: - 1. Client calls tasks/result on server (for tool task) - 2. Server delivers task-augmented elicitation through that stream - 3. Client creates its own task and returns CreateTaskResult - 4. Server polls the client's task while the client's tasks/result is still open - 5. Server gets the ElicitResult and completes the tool task - 6. Client's tasks/result returns with the CallToolResult - """ - elicit_received = Event() - work_completed = Event() - - # Client-side task store for handling task-augmented elicitation - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Task-augmented elicitation within task - server polls client - result = await task.elicit_as_task( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ttl=60000, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario4", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - task_handlers = create_client_task_handlers(client_task_store, elicit_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("confirm_action", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required or terminal, then call tasks/result - found_expected_status = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required" or is_terminal(status.status): # pragma: no branch - found_expected_status = True - break - assert found_expected_status, "Expected to see input_required or terminal status" - - # This will deliver the task-augmented elicitation, - # server will poll client, and eventually return the tool result - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert elicit_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: - """Scenario 2 for sampling: Normal tool call with task-augmented sampling. - - Server calls session.experimental.create_message_as_task(), client creates - a task for the sampling and returns CreateTaskResult. Server polls client. - """ - sampling_received = Event() - tool_result: list[str] = [] - - # Client-side task store for handling task-augmented sampling - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Task-augmented sampling - server polls client - result = await ctx.session.experimental.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ttl=60000, - ) - - assert isinstance(result.content, TextContent), "Expected TextContent response" - response_text = result.content.text - - tool_result.append(response_text) - return CallToolResult(content=[TextContent(type="text", text=response_text)]) - - server = Server("test-scenario2-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as task) - result = await client_session.call_tool("generate_text", {}) - - # Verify sampling was received and tool completed - assert sampling_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Hello from the model!" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "Hello from the model!" - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() -> None: - """Scenario 4 for sampling: Task-augmented tool call with task-augmented sampling. - - Client calls tool as task. Inside the task, server uses task.create_message_as_task() - which sends task-augmented sampling. Client creates its own task for the sampling, - and server polls the client. - """ - sampling_received = Event() - work_completed = Event() - - # Client-side task store for handling task-augmented sampling - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Task-augmented sampling within task - server polls client - result = await task.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ttl=60000, - ) - - assert isinstance(result.content, TextContent), "Expected TextContent response" - response_text = result.content.text - - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text=response_text)]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario4-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("generate_text", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required or terminal - found_expected_status = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required" or is_terminal(status.status): # pragma: no branch - found_expected_status = True - break - assert found_expected_status, "Expected to see input_required or terminal status" - - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert sampling_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "Hello from the model!" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - client_task_store.cleanup() diff --git a/tests/experimental/tasks/test_message_queue.py b/tests/experimental/tasks/test_message_queue.py deleted file mode 100644 index eca113d5b4..0000000000 --- a/tests/experimental/tasks/test_message_queue.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Tests for TaskMessageQueue and InMemoryTaskMessageQueue.""" - -from collections import deque -from datetime import datetime, timezone - -import anyio -import pytest - -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.types import JSONRPCNotification, JSONRPCRequest - - -@pytest.fixture -def queue() -> InMemoryTaskMessageQueue: - return InMemoryTaskMessageQueue() - - -def make_request(id: int = 1, method: str = "test/method") -> JSONRPCRequest: - return JSONRPCRequest(jsonrpc="2.0", id=id, method=method) - - -def make_notification(method: str = "test/notify") -> JSONRPCNotification: - return JSONRPCNotification(jsonrpc="2.0", method=method) - - -class TestInMemoryTaskMessageQueue: - @pytest.mark.anyio - async def test_enqueue_and_dequeue(self, queue: InMemoryTaskMessageQueue) -> None: - """Test basic enqueue and dequeue operations.""" - task_id = "task-1" - msg = QueuedMessage(type="request", message=make_request()) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.type == "request" - assert result.message.method == "test/method" - - @pytest.mark.anyio - async def test_dequeue_empty_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: - """Dequeue from empty queue returns None.""" - result = await queue.dequeue("nonexistent-task") - assert result is None - - @pytest.mark.anyio - async def test_fifo_ordering(self, queue: InMemoryTaskMessageQueue) -> None: - """Messages are dequeued in FIFO order.""" - task_id = "task-1" - - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1, "first"))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2, "second"))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3, "third"))) - - msg1 = await queue.dequeue(task_id) - msg2 = await queue.dequeue(task_id) - msg3 = await queue.dequeue(task_id) - - assert msg1 is not None and msg1.message.method == "first" - assert msg2 is not None and msg2.message.method == "second" - assert msg3 is not None and msg3.message.method == "third" - - @pytest.mark.anyio - async def test_separate_queues_per_task(self, queue: InMemoryTaskMessageQueue) -> None: - """Each task has its own queue.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1, "task1-msg"))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2, "task2-msg"))) - - msg1 = await queue.dequeue("task-1") - msg2 = await queue.dequeue("task-2") - - assert msg1 is not None and msg1.message.method == "task1-msg" - assert msg2 is not None and msg2.message.method == "task2-msg" - - @pytest.mark.anyio - async def test_peek_does_not_remove(self, queue: InMemoryTaskMessageQueue) -> None: - """Peek returns message without removing it.""" - task_id = "task-1" - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - peeked = await queue.peek(task_id) - dequeued = await queue.dequeue(task_id) - - assert peeked is not None - assert dequeued is not None - assert isinstance(peeked.message, JSONRPCRequest) - assert isinstance(dequeued.message, JSONRPCRequest) - assert peeked.message.id == dequeued.message.id - - @pytest.mark.anyio - async def test_is_empty(self, queue: InMemoryTaskMessageQueue) -> None: - """Test is_empty method.""" - task_id = "task-1" - - assert await queue.is_empty(task_id) is True - - await queue.enqueue(task_id, QueuedMessage(type="notification", message=make_notification())) - assert await queue.is_empty(task_id) is False - - await queue.dequeue(task_id) - assert await queue.is_empty(task_id) is True - - @pytest.mark.anyio - async def test_clear_returns_all_messages(self, queue: InMemoryTaskMessageQueue) -> None: - """Clear removes and returns all messages.""" - task_id = "task-1" - - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3))) - - messages = await queue.clear(task_id) - - assert len(messages) == 3 - assert await queue.is_empty(task_id) is True - - @pytest.mark.anyio - async def test_clear_empty_queue(self, queue: InMemoryTaskMessageQueue) -> None: - """Clear on empty queue returns empty list.""" - messages = await queue.clear("nonexistent") - assert messages == [] - - @pytest.mark.anyio - async def test_notification_messages(self, queue: InMemoryTaskMessageQueue) -> None: - """Test queuing notification messages.""" - task_id = "task-1" - msg = QueuedMessage(type="notification", message=make_notification("log/message")) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.type == "notification" - assert result.message.method == "log/message" - - @pytest.mark.anyio - async def test_message_timestamp(self, queue: InMemoryTaskMessageQueue) -> None: - """Messages have timestamps.""" - before = datetime.now(timezone.utc) - msg = QueuedMessage(type="request", message=make_request()) - after = datetime.now(timezone.utc) - - assert before <= msg.timestamp <= after - - @pytest.mark.anyio - async def test_message_with_resolver(self, queue: InMemoryTaskMessageQueue) -> None: - """Messages can have resolvers.""" - task_id = "task-1" - resolver: Resolver[dict[str, str]] = Resolver() - - msg = QueuedMessage( - type="request", - message=make_request(), - resolver=resolver, - original_request_id=42, - ) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.resolver is resolver - assert result.original_request_id == 42 - - @pytest.mark.anyio - async def test_cleanup_specific_task(self, queue: InMemoryTaskMessageQueue) -> None: - """Cleanup removes specific task's data.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) - - queue.cleanup("task-1") - - assert await queue.is_empty("task-1") is True - assert await queue.is_empty("task-2") is False - - @pytest.mark.anyio - async def test_cleanup_all(self, queue: InMemoryTaskMessageQueue) -> None: - """Cleanup without task_id removes all data.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) - - queue.cleanup() - - assert await queue.is_empty("task-1") is True - assert await queue.is_empty("task-2") is True - - @pytest.mark.anyio - async def test_wait_for_message_returns_immediately_if_message_exists( - self, queue: InMemoryTaskMessageQueue - ) -> None: - """wait_for_message returns immediately if queue not empty.""" - task_id = "task-1" - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - # Should return immediately, not block - with anyio.fail_after(1): - await queue.wait_for_message(task_id) - - @pytest.mark.anyio - async def test_wait_for_message_blocks_until_message(self, queue: InMemoryTaskMessageQueue) -> None: - """wait_for_message blocks until a message is enqueued.""" - task_id = "task-1" - received = False - waiter_started = anyio.Event() - - async def enqueue_when_ready() -> None: - # Wait until the waiter has started before enqueueing - await waiter_started.wait() - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - async def wait_for_msg() -> None: - nonlocal received - # Signal that we're about to start waiting - waiter_started.set() - await queue.wait_for_message(task_id) - received = True - - async with anyio.create_task_group() as tg: - tg.start_soon(wait_for_msg) - tg.start_soon(enqueue_when_ready) - - assert received is True - - @pytest.mark.anyio - async def test_notify_message_available_wakes_waiter(self, queue: InMemoryTaskMessageQueue) -> None: - """notify_message_available wakes up waiting coroutines.""" - task_id = "task-1" - notified = False - waiter_started = anyio.Event() - - async def notify_when_ready() -> None: - # Wait until the waiter has started before notifying - await waiter_started.wait() - await queue.notify_message_available(task_id) - - async def wait_for_notification() -> None: - nonlocal notified - # Signal that we're about to start waiting - waiter_started.set() - await queue.wait_for_message(task_id) - notified = True - - async with anyio.create_task_group() as tg: - tg.start_soon(wait_for_notification) - tg.start_soon(notify_when_ready) - - assert notified is True - - @pytest.mark.anyio - async def test_peek_empty_queue_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: - """Peek on empty queue returns None.""" - result = await queue.peek("nonexistent-task") - assert result is None - - @pytest.mark.anyio - async def test_wait_for_message_double_check_race_condition(self, queue: InMemoryTaskMessageQueue) -> None: - """wait_for_message returns early if message arrives after event creation but before wait.""" - task_id = "task-1" - - # To test the double-check path (lines 223-225), we need a message to arrive - # after the event is created (line 220) but before event.wait() (line 228). - # We simulate this by injecting a message before is_empty is called the second time. - - original_is_empty = queue.is_empty - call_count = 0 - - async def is_empty_with_injection(tid: str) -> bool: - nonlocal call_count - call_count += 1 - if call_count == 2 and tid == task_id: - # Before second check, inject a message - this simulates a message - # arriving between event creation and the double-check - queue._queues[task_id] = deque([QueuedMessage(type="request", message=make_request())]) - return await original_is_empty(tid) - - queue.is_empty = is_empty_with_injection # type: ignore[method-assign] - - # Should return immediately due to double-check finding the message - with anyio.fail_after(1): - await queue.wait_for_message(task_id) - - -class TestResolver: - @pytest.mark.anyio - async def test_set_result_and_wait(self) -> None: - """Test basic set_result and wait flow.""" - resolver: Resolver[str] = Resolver() - - resolver.set_result("hello") - result = await resolver.wait() - - assert result == "hello" - assert resolver.done() - - @pytest.mark.anyio - async def test_set_exception_and_wait(self) -> None: - """Test set_exception raises on wait.""" - resolver: Resolver[str] = Resolver() - - resolver.set_exception(ValueError("test error")) - - with pytest.raises(ValueError, match="test error"): - await resolver.wait() - - assert resolver.done() - - @pytest.mark.anyio - async def test_set_result_when_already_completed_raises(self) -> None: - """Test that set_result raises if resolver already completed.""" - resolver: Resolver[str] = Resolver() - resolver.set_result("first") - - with pytest.raises(RuntimeError, match="already completed"): - resolver.set_result("second") - - @pytest.mark.anyio - async def test_set_exception_when_already_completed_raises(self) -> None: - """Test that set_exception raises if resolver already completed.""" - resolver: Resolver[str] = Resolver() - resolver.set_result("done") - - with pytest.raises(RuntimeError, match="already completed"): - resolver.set_exception(ValueError("too late")) - - @pytest.mark.anyio - async def test_done_returns_false_before_completion(self) -> None: - """Test done() returns False before any result is set.""" - resolver: Resolver[str] = Resolver() - assert resolver.done() is False diff --git a/tests/experimental/tasks/test_request_context.py b/tests/experimental/tasks/test_request_context.py deleted file mode 100644 index ad4023389e..0000000000 --- a/tests/experimental/tasks/test_request_context.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Tests for the RequestContext.experimental (Experimental class) task validation helpers.""" - -import pytest - -from mcp.server.experimental.request_context import Experimental -from mcp.shared.exceptions import MCPError -from mcp.types import ( - METHOD_NOT_FOUND, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - ClientCapabilities, - ClientTasksCapability, - TaskMetadata, - Tool, - ToolExecution, -) - - -def test_is_task_true_when_metadata_present() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - assert exp.is_task is True - - -def test_is_task_false_when_no_metadata() -> None: - exp = Experimental(task_metadata=None) - assert exp.is_task is False - - -def test_client_supports_tasks_true() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) - assert exp.client_supports_tasks is True - - -def test_client_supports_tasks_false_no_tasks() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.client_supports_tasks is False - - -def test_client_supports_tasks_false_no_capabilities() -> None: - exp = Experimental(_client_capabilities=None) - assert exp.client_supports_tasks is False - - -def test_validate_task_mode_required_with_task_is_valid() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) - assert error is None - - -def test_validate_task_mode_required_without_task_returns_error() -> None: - exp = Experimental(task_metadata=None) - error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) - assert error is not None - assert error.code == METHOD_NOT_FOUND - assert "requires task-augmented" in error.message - - -def test_validate_task_mode_required_without_task_raises_by_default() -> None: - exp = Experimental(task_metadata=None) - with pytest.raises(MCPError) as exc_info: - exp.validate_task_mode(TASK_REQUIRED) - assert exc_info.value.error.code == METHOD_NOT_FOUND - - -def test_validate_task_mode_forbidden_without_task_is_valid() -> None: - exp = Experimental(task_metadata=None) - error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) - assert error is None - - -def test_validate_task_mode_forbidden_with_task_returns_error() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) - assert error is not None - assert error.code == METHOD_NOT_FOUND - assert "does not support task-augmented" in error.message - - -def test_validate_task_mode_forbidden_with_task_raises_by_default() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - with pytest.raises(MCPError) as exc_info: - exp.validate_task_mode(TASK_FORBIDDEN) - assert exc_info.value.error.code == METHOD_NOT_FOUND - - -def test_validate_task_mode_none_treated_as_forbidden() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(None, raise_error=False) - assert error is not None - assert "does not support task-augmented" in error.message - - -def test_validate_task_mode_optional_with_task_is_valid() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) - assert error is None - - -def test_validate_task_mode_optional_without_task_is_valid() -> None: - exp = Experimental(task_metadata=None) - error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) - assert error is None - - -def test_validate_for_tool_with_execution_required() -> None: - exp = Experimental(task_metadata=None) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is not None - assert "requires task-augmented" in error.message - - -def test_validate_for_tool_without_execution() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=None, - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is not None - assert "does not support task-augmented" in error.message - - -def test_validate_for_tool_optional_with_task() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is None - - -def test_can_use_tool_required_with_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) - assert exp.can_use_tool(TASK_REQUIRED) is True - - -def test_can_use_tool_required_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_REQUIRED) is False - - -def test_can_use_tool_optional_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_OPTIONAL) is True - - -def test_can_use_tool_forbidden_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_FORBIDDEN) is True - - -def test_can_use_tool_none_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(None) is True diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py deleted file mode 100644 index 38d7d0a664..0000000000 --- a/tests/experimental/tasks/test_spec_compliance.py +++ /dev/null @@ -1,717 +0,0 @@ -"""Tasks Spec Compliance Tests -=========================== - -Test structure mirrors: https://modelcontextprotocol.io/specification/draft/basic/utilities/tasks.md - -Each section contains tests for normative requirements (MUST/SHOULD/MAY). -""" - -from datetime import datetime, timezone - -import pytest - -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY -from mcp.types import ( - CancelTaskRequestParams, - CancelTaskResult, - CreateTaskResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - PaginatedRequestParams, - ServerCapabilities, - Task, -) - -# Shared test datetime -TEST_DATETIME = datetime(2025, 1, 1, tzinfo=timezone.utc) - - -def _get_capabilities(server: Server) -> ServerCapabilities: - """Helper to get capabilities from a server.""" - return server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) - - -def test_server_without_task_handlers_has_no_tasks_capability() -> None: - """Server without any task handlers has no tasks capability.""" - server: Server = Server("test") - caps = _get_capabilities(server) - assert caps.tasks is None - - -async def _noop_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - -async def _noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - -async def _noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - raise NotImplementedError - - -def test_server_with_list_tasks_handler_declares_list_capability() -> None: - """Server with list_tasks handler declares tasks.list capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_list_tasks=_noop_list_tasks) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is not None - - -def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: - """Server with cancel_task handler declares tasks.cancel capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_cancel_task=_noop_cancel_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.cancel is not None - - -def test_server_with_get_task_handler_declares_requests_tools_call_capability() -> None: - """Server with get_task handler declares tasks.requests.tools.call capability. - (get_task is required for task-augmented tools/call support) - """ - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.requests is not None - assert caps.tasks.requests.tools is not None - - -@pytest.mark.skip( - reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. Low-level API should support " - "selectively enabling/disabling task capabilities." -) -def test_server_without_list_handler_has_no_list_capability() -> None: # pragma: no cover - """Server without list_tasks handler has no tasks.list capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is None - - -@pytest.mark.skip( - reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. Low-level API should support " - "selectively enabling/disabling task capabilities." -) -def test_server_without_cancel_handler_has_no_cancel_capability() -> None: # pragma: no cover - """Server without cancel_task handler has no tasks.cancel capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.cancel is None - - -def test_server_with_all_task_handlers_has_full_capability() -> None: - """Server with all task handlers declares complete tasks capability.""" - server: Server = Server("test") - server.experimental.enable_tasks( - on_list_tasks=_noop_list_tasks, - on_cancel_task=_noop_cancel_task, - on_get_task=_noop_get_task, - ) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is not None - assert caps.tasks.cancel is not None - assert caps.tasks.requests is not None - assert caps.tasks.requests.tools is not None - - -class TestClientCapabilities: - """Clients declare: - - tasks.list — supports listing operations - - tasks.cancel — supports cancellation - - tasks.requests.sampling.createMessage — task-augmented sampling - - tasks.requests.elicitation.create — task-augmented elicitation - """ - - def test_client_declares_tasks_capability(self) -> None: - """Client can declare tasks capability.""" - pytest.skip("TODO") - - -class TestToolLevelNegotiation: - """Tools in tools/list responses include execution.taskSupport with values: - - Not present or "forbidden": No task augmentation allowed - - "optional": Task augmentation allowed at requestor discretion - - "required": Task augmentation is mandatory - """ - - def test_tool_execution_task_forbidden_rejects_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="forbidden" MUST reject task-augmented calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_absent_rejects_task_augmented_call(self) -> None: - """Tool without execution.taskSupport MUST reject task-augmented calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_optional_accepts_normal_call(self) -> None: - """Tool with execution.taskSupport="optional" accepts normal calls.""" - pytest.skip("TODO") - - def test_tool_execution_task_optional_accepts_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="optional" accepts task-augmented calls.""" - pytest.skip("TODO") - - def test_tool_execution_task_required_rejects_normal_call(self) -> None: - """Tool with execution.taskSupport="required" MUST reject non-task calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_required_accepts_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="required" accepts task-augmented calls.""" - pytest.skip("TODO") - - -class TestCapabilityNegotiation: - """Requestors SHOULD only augment requests with a task if the corresponding - capability has been declared by the receiver. - - Receivers that do not declare the task capability for a request type - MUST process requests of that type normally, ignoring any task-augmentation - metadata if present. - """ - - def test_receiver_without_capability_ignores_task_metadata(self) -> None: - """Receiver without task capability MUST process request normally, - ignoring task-augmentation metadata. - """ - pytest.skip("TODO") - - def test_receiver_with_capability_may_require_task_augmentation(self) -> None: - """Receivers that declare task capability MAY return error (-32600) - for non-task-augmented requests, requiring task augmentation. - """ - pytest.skip("TODO") - - -class TestTaskStatusLifecycle: - """Tasks begin in working status and follow valid transitions: - working → input_required → working → terminal - working → terminal (directly) - input_required → terminal (directly) - - Terminal states (no further transitions allowed): - - completed - - failed - - cancelled - """ - - def test_task_begins_in_working_status(self) -> None: - """Tasks MUST begin in working status.""" - pytest.skip("TODO") - - def test_working_to_completed_transition(self) -> None: - """working → completed is valid.""" - pytest.skip("TODO") - - def test_working_to_failed_transition(self) -> None: - """working → failed is valid.""" - pytest.skip("TODO") - - def test_working_to_cancelled_transition(self) -> None: - """working → cancelled is valid.""" - pytest.skip("TODO") - - def test_working_to_input_required_transition(self) -> None: - """working → input_required is valid.""" - pytest.skip("TODO") - - def test_input_required_to_working_transition(self) -> None: - """input_required → working is valid.""" - pytest.skip("TODO") - - def test_input_required_to_terminal_transition(self) -> None: - """input_required → terminal is valid.""" - pytest.skip("TODO") - - def test_terminal_state_no_further_transitions(self) -> None: - """Terminal states allow no further transitions.""" - pytest.skip("TODO") - - def test_completed_is_terminal(self) -> None: - """completed is a terminal state.""" - pytest.skip("TODO") - - def test_failed_is_terminal(self) -> None: - """failed is a terminal state.""" - pytest.skip("TODO") - - def test_cancelled_is_terminal(self) -> None: - """cancelled is a terminal state.""" - pytest.skip("TODO") - - -class TestInputRequiredStatus: - """When a receiver needs information to proceed, it moves the task to input_required. - The requestor should call tasks/result to retrieve input requests. - The task must include io.modelcontextprotocol/related-task metadata in associated requests. - """ - - def test_input_required_status_retrievable_via_tasks_get(self) -> None: - """Task in input_required status is retrievable via tasks/get.""" - pytest.skip("TODO") - - def test_input_required_related_task_metadata_in_requests(self) -> None: - """Task MUST include io.modelcontextprotocol/related-task metadata - in associated requests. - """ - pytest.skip("TODO") - - -class TestCreatingTask: - """Request structure: - {"method": "tools/call", "params": {"name": "...", "arguments": {...}, "task": {"ttl": 60000}}} - - Response (CreateTaskResult): - {"result": {"task": {"taskId": "...", "status": "working", ...}}} - - Receivers may include io.modelcontextprotocol/model-immediate-response in _meta. - """ - - def test_task_augmented_request_returns_create_task_result(self) -> None: - """Task-augmented request MUST return CreateTaskResult immediately.""" - pytest.skip("TODO") - - def test_create_task_result_contains_task_id(self) -> None: - """CreateTaskResult MUST contain taskId.""" - pytest.skip("TODO") - - def test_create_task_result_contains_status_working(self) -> None: - """CreateTaskResult MUST have status=working initially.""" - pytest.skip("TODO") - - def test_create_task_result_contains_created_at(self) -> None: - """CreateTaskResult MUST contain createdAt timestamp.""" - pytest.skip("TODO") - - def test_create_task_result_created_at_is_iso8601(self) -> None: - """createdAt MUST be ISO 8601 formatted.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_ttl(self) -> None: - """CreateTaskResult MAY contain ttl.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_poll_interval(self) -> None: - """CreateTaskResult MAY contain pollInterval.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_status_message(self) -> None: - """CreateTaskResult MAY contain statusMessage.""" - pytest.skip("TODO") - - def test_receiver_may_override_requested_ttl(self) -> None: - """Receiver MAY override requested ttl but MUST return actual value.""" - pytest.skip("TODO") - - def test_model_immediate_response_in_meta(self) -> None: - """Receiver MAY include io.modelcontextprotocol/model-immediate-response - in _meta to provide immediate response while task executes. - """ - # Verify the constant has the correct value per spec - assert MODEL_IMMEDIATE_RESPONSE_KEY == "io.modelcontextprotocol/model-immediate-response" - - # CreateTaskResult can include model-immediate-response in _meta - task = Task( - task_id="test-123", - status="working", - created_at=TEST_DATETIME, - last_updated_at=TEST_DATETIME, - ttl=60000, - ) - immediate_msg = "Task started, processing your request..." - # Note: Must use _meta= (alias) not meta= due to Pydantic alias handling - result = CreateTaskResult( - task=task, - **{"_meta": {MODEL_IMMEDIATE_RESPONSE_KEY: immediate_msg}}, - ) - - # Verify the metadata is present and correct - assert result.meta is not None - assert MODEL_IMMEDIATE_RESPONSE_KEY in result.meta - assert result.meta[MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg - - # Verify it serializes correctly with _meta alias - serialized = result.model_dump(by_alias=True) - assert "_meta" in serialized - assert MODEL_IMMEDIATE_RESPONSE_KEY in serialized["_meta"] - assert serialized["_meta"][MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg - - -class TestGettingTaskStatus: - """Request: {"method": "tasks/get", "params": {"taskId": "..."}} - Response: Returns full Task object with current status and pollInterval. - """ - - def test_tasks_get_returns_task_object(self) -> None: - """tasks/get MUST return full Task object.""" - pytest.skip("TODO") - - def test_tasks_get_returns_current_status(self) -> None: - """tasks/get MUST return current status.""" - pytest.skip("TODO") - - def test_tasks_get_may_return_poll_interval(self) -> None: - """tasks/get MAY return pollInterval.""" - pytest.skip("TODO") - - def test_tasks_get_invalid_task_id_returns_error(self) -> None: - """tasks/get with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_get_nonexistent_task_id_returns_error(self) -> None: - """tasks/get with nonexistent taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestRetrievingResults: - """Request: {"method": "tasks/result", "params": {"taskId": "..."}} - Response: The actual operation result structure (e.g., CallToolResult). - - This call blocks until terminal status. - """ - - def test_tasks_result_returns_underlying_result(self) -> None: - """tasks/result MUST return exactly what underlying request would return.""" - pytest.skip("TODO") - - def test_tasks_result_blocks_until_terminal(self) -> None: - """tasks/result MUST block for non-terminal tasks.""" - pytest.skip("TODO") - - def test_tasks_result_unblocks_on_terminal(self) -> None: - """tasks/result MUST unblock upon reaching terminal status.""" - pytest.skip("TODO") - - def test_tasks_result_includes_related_task_metadata(self) -> None: - """tasks/result MUST include io.modelcontextprotocol/related-task in _meta.""" - pytest.skip("TODO") - - def test_tasks_result_returns_error_for_failed_task(self) -> None: - """tasks/result returns the same error the underlying request - would have produced for failed tasks. - """ - pytest.skip("TODO") - - def test_tasks_result_invalid_task_id_returns_error(self) -> None: - """tasks/result with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestListingTasks: - """Request: {"method": "tasks/list", "params": {"cursor": "optional"}} - Response: Array of tasks with pagination support via nextCursor. - """ - - def test_tasks_list_returns_array_of_tasks(self) -> None: - """tasks/list MUST return array of tasks.""" - pytest.skip("TODO") - - def test_tasks_list_pagination_with_cursor(self) -> None: - """tasks/list supports pagination via cursor.""" - pytest.skip("TODO") - - def test_tasks_list_returns_next_cursor_when_more_results(self) -> None: - """tasks/list MUST return nextCursor when more results available.""" - pytest.skip("TODO") - - def test_tasks_list_cursors_are_opaque(self) -> None: - """Implementers MUST treat cursors as opaque tokens.""" - pytest.skip("TODO") - - def test_tasks_list_invalid_cursor_returns_error(self) -> None: - """tasks/list with invalid cursor MUST return -32602.""" - pytest.skip("TODO") - - -class TestCancellingTasks: - """Request: {"method": "tasks/cancel", "params": {"taskId": "..."}} - Response: Returns the task object with status: "cancelled". - """ - - def test_tasks_cancel_returns_cancelled_task(self) -> None: - """tasks/cancel MUST return task with status=cancelled.""" - pytest.skip("TODO") - - def test_tasks_cancel_terminal_task_returns_error(self) -> None: - """Cancelling already-terminal task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_completed_task_returns_error(self) -> None: - """Cancelling completed task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_failed_task_returns_error(self) -> None: - """Cancelling failed task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_already_cancelled_task_returns_error(self) -> None: - """Cancelling already-cancelled task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_invalid_task_id_returns_error(self) -> None: - """tasks/cancel with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestStatusNotifications: - """Receivers MAY send: {"method": "notifications/tasks/status", "params": {...}} - These are optional; requestors MUST NOT rely on them and SHOULD continue polling. - """ - - def test_receiver_may_send_status_notification(self) -> None: - """Receiver MAY send notifications/tasks/status.""" - pytest.skip("TODO") - - def test_status_notification_contains_task_id(self) -> None: - """Status notification MUST contain taskId.""" - pytest.skip("TODO") - - def test_status_notification_contains_status(self) -> None: - """Status notification MUST contain status.""" - pytest.skip("TODO") - - -class TestTaskManagement: - """- Receivers generate unique task IDs as strings - - Tasks must begin in working status - - createdAt timestamps must be ISO 8601 formatted - - Receivers may override requested ttl but must return actual value - - Receivers may delete tasks after TTL expires - - All task-related messages must include io.modelcontextprotocol/related-task - in _meta except for tasks/get, tasks/list, tasks/cancel operations - """ - - def test_task_ids_are_unique_strings(self) -> None: - """Receivers MUST generate unique task IDs as strings.""" - pytest.skip("TODO") - - def test_multiple_tasks_have_unique_ids(self) -> None: - """Multiple tasks MUST have unique IDs.""" - pytest.skip("TODO") - - def test_receiver_may_delete_tasks_after_ttl(self) -> None: - """Receivers MAY delete tasks after TTL expires.""" - pytest.skip("TODO") - - def test_related_task_metadata_in_task_messages(self) -> None: - """All task-related messages MUST include io.modelcontextprotocol/related-task - in _meta. - """ - pytest.skip("TODO") - - def test_tasks_get_does_not_require_related_task_metadata(self) -> None: - """tasks/get does not require related-task metadata.""" - pytest.skip("TODO") - - def test_tasks_list_does_not_require_related_task_metadata(self) -> None: - """tasks/list does not require related-task metadata.""" - pytest.skip("TODO") - - def test_tasks_cancel_does_not_require_related_task_metadata(self) -> None: - """tasks/cancel does not require related-task metadata.""" - pytest.skip("TODO") - - -class TestResultHandling: - """- Receivers must return CreateTaskResult immediately upon accepting task-augmented requests - - tasks/result must return exactly what the underlying request would return - - tasks/result blocks for non-terminal tasks; must unblock upon reaching terminal status - """ - - def test_create_task_result_returned_immediately(self) -> None: - """Receiver MUST return CreateTaskResult immediately (not after work completes).""" - pytest.skip("TODO") - - def test_tasks_result_matches_underlying_result_structure(self) -> None: - """tasks/result MUST return same structure as underlying request.""" - pytest.skip("TODO") - - def test_tasks_result_for_tool_call_returns_call_tool_result(self) -> None: - """tasks/result for tools/call returns CallToolResult.""" - pytest.skip("TODO") - - -class TestProgressTracking: - """Task-augmented requests support progress notifications using the progressToken - mechanism, which remains valid throughout the task lifetime. - """ - - def test_progress_token_valid_throughout_task_lifetime(self) -> None: - """progressToken remains valid throughout task lifetime.""" - pytest.skip("TODO") - - def test_progress_notifications_sent_during_task_execution(self) -> None: - """Progress notifications can be sent during task execution.""" - pytest.skip("TODO") - - -class TestProtocolErrors: - """Protocol Errors (JSON-RPC standard codes): - - -32600 (Invalid request): Non-task requests to endpoint requiring task augmentation - - -32602 (Invalid params): Invalid/nonexistent taskId, invalid cursor, cancel terminal task - - -32603 (Internal error): Server-side execution failures - """ - - def test_invalid_request_for_required_task_augmentation(self) -> None: - """Non-task request to task-required endpoint returns -32600.""" - pytest.skip("TODO") - - def test_invalid_params_for_invalid_task_id(self) -> None: - """Invalid taskId returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_nonexistent_task_id(self) -> None: - """Nonexistent taskId returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_invalid_cursor(self) -> None: - """Invalid cursor in tasks/list returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_cancel_terminal_task(self) -> None: - """Attempt to cancel terminal task returns -32602.""" - pytest.skip("TODO") - - def test_internal_error_for_server_failure(self) -> None: - """Server-side execution failure returns -32603.""" - pytest.skip("TODO") - - -class TestTaskExecutionErrors: - """When underlying requests fail, the task moves to failed status. - - tasks/get response should include statusMessage explaining failure - - tasks/result returns same error the underlying request would have produced - - For tool calls, isError: true moves task to failed status - """ - - def test_underlying_failure_moves_task_to_failed(self) -> None: - """Underlying request failure moves task to failed status.""" - pytest.skip("TODO") - - def test_failed_task_has_status_message(self) -> None: - """Failed task SHOULD include statusMessage explaining failure.""" - pytest.skip("TODO") - - def test_tasks_result_returns_underlying_error(self) -> None: - """tasks/result returns same error underlying request would produce.""" - pytest.skip("TODO") - - def test_tool_call_is_error_true_moves_to_failed(self) -> None: - """Tool call with isError: true moves task to failed status.""" - pytest.skip("TODO") - - -class TestTaskObject: - """Task Object fields: - - taskId: String identifier - - status: Current execution state - - statusMessage: Optional human-readable description - - createdAt: ISO 8601 timestamp of creation - - ttl: Milliseconds before potential deletion - - pollInterval: Suggested milliseconds between polls - """ - - def test_task_has_task_id_string(self) -> None: - """Task MUST have taskId as string.""" - pytest.skip("TODO") - - def test_task_has_status(self) -> None: - """Task MUST have status.""" - pytest.skip("TODO") - - def test_task_status_message_is_optional(self) -> None: - """Task statusMessage is optional.""" - pytest.skip("TODO") - - def test_task_has_created_at(self) -> None: - """Task MUST have createdAt.""" - pytest.skip("TODO") - - def test_task_ttl_is_optional(self) -> None: - """Task ttl is optional.""" - pytest.skip("TODO") - - def test_task_poll_interval_is_optional(self) -> None: - """Task pollInterval is optional.""" - pytest.skip("TODO") - - -class TestRelatedTaskMetadata: - """Related Task Metadata structure: - {"_meta": {"io.modelcontextprotocol/related-task": {"taskId": "..."}}} - """ - - def test_related_task_metadata_structure(self) -> None: - """Related task metadata has correct structure.""" - pytest.skip("TODO") - - def test_related_task_metadata_contains_task_id(self) -> None: - """Related task metadata contains taskId.""" - pytest.skip("TODO") - - -class TestAccessAndIsolation: - """- Task IDs enable access to sensitive results - - Authorization context binding is essential where available - - For non-authorized environments: strong entropy IDs, strict TTL limits - """ - - def test_task_bound_to_authorization_context(self) -> None: - """Receivers receiving authorization context MUST bind tasks to that context.""" - pytest.skip("TODO") - - def test_reject_task_operations_outside_authorization_context(self) -> None: - """Receivers MUST reject task operations for tasks outside - requestor's authorization context. - """ - pytest.skip("TODO") - - def test_non_authorized_environments_use_secure_ids(self) -> None: - """For non-authorized environments, receivers SHOULD use - cryptographically secure IDs. - """ - pytest.skip("TODO") - - def test_non_authorized_environments_use_shorter_ttls(self) -> None: - """For non-authorized environments, receivers SHOULD use shorter TTLs.""" - pytest.skip("TODO") - - -class TestResourceLimits: - """Receivers should: - - Enforce concurrent task limits per requestor - - Implement maximum TTL constraints - - Clean up expired tasks promptly - """ - - def test_concurrent_task_limit_enforced(self) -> None: - """Receiver SHOULD enforce concurrent task limits per requestor.""" - pytest.skip("TODO") - - def test_maximum_ttl_constraint_enforced(self) -> None: - """Receiver SHOULD implement maximum TTL constraints.""" - pytest.skip("TODO") - - def test_expired_tasks_cleaned_up(self) -> None: - """Receiver SHOULD clean up expired tasks promptly.""" - pytest.skip("TODO") diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 109b30fc77..02a2e033fe 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -48,11 +48,6 @@ _SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") -_TASKS_DEFERRAL = ( - "Tasks are experimental and the spec is being substantially revised; python task behaviour is " - "covered by tests/experimental/tasks/ until the next spec revision settles." -) - @dataclass(frozen=True, kw_only=True) class Divergence: @@ -362,11 +357,6 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic#responses", behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", ), - "protocol:meta:related-task": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", - behavior="Messages may carry related-task _meta associating them with a task.", - deferred=_TASKS_DEFERRAL, - ), "meta:request-to-handler": Requirement( source=f"{SPEC_BASE_URL}/basic#_meta", behavior="The _meta object the client attaches to a request is visible to the server handler.", @@ -1546,190 +1536,6 @@ def __post_init__(self) -> None: ), ), # ═══════════════════════════════════════════════════════════════════════════ - # Tasks (experimental) - # ═══════════════════════════════════════════════════════════════════════════ - "tasks:auth:context-isolation": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-isolation-and-access-control", - behavior=( - "When an authorization context is available, task operations are scoped to the context that " - "created the task: other contexts cannot get it, retrieve its result, cancel it, or see it in " - "tasks/list." - ), - transports=("streamable-http",), - deferred=_TASKS_DEFERRAL, - ), - "tasks:bidirectional": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#definitions", - behavior="Task APIs are bidirectional: the server may create, get, list, and cancel tasks on the client.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:cancel:no-handler-abort": Requirement( - source="sdk", - behavior=( - "tasks/cancel marks the task cancelled without aborting the originating request handler " - "(the spec says receivers SHOULD attempt to stop execution)." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:cancel:remains-cancelled": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", - behavior=( - "After tasks/cancel, the task remains cancelled even if the underlying handler subsequently " - "completes or fails." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:cancel:terminal-rejected": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", - behavior="tasks/cancel on a task already in a terminal state returns Invalid params (-32602).", - deferred=_TASKS_DEFERRAL, - ), - "tasks:cancel:working": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", - behavior="tasks/cancel on a working task transitions it to cancelled and returns the updated task.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:create:ttl-honored": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#ttl-and-resource-management", - behavior=( - "tasks/get responses include the actual ttl applied by the receiver (or null for unlimited); " - "the create-task result carries the same value." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:create:via-tool-call": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#creating-tasks", - behavior="A task-augmented tools/call returns a create-task result instead of the tool result.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:get": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#getting-tasks", - behavior="tasks/get returns the task's current status, ttl, timestamps, and status message.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:lifecycle:initial-working": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-lifecycle", - behavior="A newly created task has status 'working'.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:lifecycle:input-required": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", - behavior=( - "While a task awaits a side-channel client response its status is input_required; once the " - "response arrives the task leaves input_required (typically returning to working)." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:list:invalid-cursor": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", - behavior="tasks/list with an invalid cursor returns Invalid params (-32602).", - deferred=_TASKS_DEFERRAL, - ), - "tasks:list:pagination": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#listing-tasks", - behavior="tasks/list returns created tasks and supports cursor pagination.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:no-capability:ignore-task-param": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-support-and-handling", - behavior=( - "A receiver that did not declare task capability for a request type processes the request " - "normally and returns the ordinary result, ignoring the task augmentation." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:progress:after-create": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-progress-notifications", - behavior=( - "After the create-task result, progress notifications keyed to the original progress token " - "continue to reach the caller until the task is terminal." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:request-cancel:no-task-cancel": Requirement( - source="sdk", - behavior="A cancellation notification for the originating request does not auto-cancel the created task.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:result:failed": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-execution-errors", - behavior="tasks/result for a failed task returns the failure result (isError true).", - deferred=_TASKS_DEFERRAL, - ), - "tasks:result:related-task-meta": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", - behavior="The tasks/result response carries related-task _meta naming the requested task.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:result:terminal": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#result-retrieval", - behavior="tasks/result for a completed task returns the stored result of the original request type.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:side-channel:drain-fifo": Requirement( - source="sdk", - behavior="tasks/result drains queued related-task messages in FIFO order before returning the final result.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:side-channel:drop-on-cancel": Requirement( - source="sdk", - behavior="When a task is cancelled before tasks/result, queued related-task messages are dropped.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:side-channel:elicitation": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", - behavior=( - "An elicitation issued mid-task is delivered through the tasks/result side-channel, and the " - "client's response routes back to the handler." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:side-channel:queue": Requirement( - source="sdk", - behavior=( - "Server-to-client requests with related-task metadata sent while no tasks/result is open are queued." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:side-channel:sampling": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", - behavior=( - "A sampling request issued mid-task is delivered through the tasks/result side-channel, and " - "the client's response routes back to the task." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:side-channel:stream": Requirement( - source="sdk", - behavior=( - "Calling tasks/result while the task is working streams related-task messages as they are " - "produced, then returns the result." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:status-notification": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-notification", - behavior="Task status notifications deliver status updates carrying the full task fields.", - deferred=_TASKS_DEFERRAL, - ), - "tasks:tool-level:forbidden-with-task-32601": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", - behavior=( - "A task-augmented tools/call on a tool that does not support tasks returns Method not found (-32601)." - ), - deferred=_TASKS_DEFERRAL, - ), - "tasks:tool-level:required-no-task-32601": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", - behavior=("A plain tools/call on a tool that requires task augmentation returns Method not found (-32601)."), - deferred=_TASKS_DEFERRAL, - ), - "tasks:unknown-id": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", - behavior="tasks/get, tasks/result, and tasks/cancel for an unknown task id return Invalid params (-32602).", - deferred=_TASKS_DEFERRAL, - ), - # ═══════════════════════════════════════════════════════════════════════════ # Transports (in-suite coverage) # ═══════════════════════════════════════════════════════════════════════════ "transport:streamable-http:stateful": Requirement( diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 5d5f8b8fc9..bef44928ac 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -3,7 +3,6 @@ import pytest from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context pytestmark = pytest.mark.anyio @@ -22,7 +21,6 @@ async def test_progress_token_zero_first_call(): session=mock_session, meta={"progress_token": 0}, lifespan_context=None, - experimental=Experimental(), ) # Create context with our mocks diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3457ec944a..21352b5f2f 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -11,7 +11,6 @@ from mcp.client import Client from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.prompts.base import Message, UserMessage @@ -1502,7 +1501,6 @@ async def test_report_progress_passes_related_request_id(): session=mock_session, meta={"progress_token": "tok-1"}, lifespan_context=None, - experimental=Experimental(), ) ctx = Context(request_context=request_context, mcp_server=MagicMock()) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a2786d865d..6116a7c7f5 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -76,6 +76,49 @@ async def run_server(): assert received_initialized +@pytest.mark.anyio +async def test_check_client_capability(): + """check_client_capability reflects the capabilities sent by the client at initialize.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + + initialized = anyio.Event() + + async def list_roots_callback(context: Any) -> types.ListRootsResult: # pragma: no cover + return types.ListRootsResult(roots=[]) + + async def run_server(server_session: ServerSession): + async for message in server_session.incoming_messages: # pragma: no branch + if isinstance(message, ClientNotification) and isinstance( + message, InitializedNotification + ): # pragma: no branch + initialized.set() + return + + async with ( + ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions(server_name="mcp", server_version="0.1.0", capabilities=ServerCapabilities()), + ) as server_session, + ClientSession( + server_to_client_receive, + client_to_server_send, + list_roots_callback=list_roots_callback, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server, server_session) + await client_session.initialize() + with anyio.fail_after(5): + await initialized.wait() + + # ClientSession advertises roots when a list_roots_callback is provided. + assert server_session.check_client_capability(types.ClientCapabilities(roots=types.RootsCapability())) + # ClientSession does not advertise sampling without a sampling_callback. + assert not server_session.check_client_capability(types.ClientCapabilities(sampling=types.SamplingCapability())) + + @pytest.mark.anyio async def test_server_capabilities(): notification_options = NotificationOptions() diff --git a/uv.lock b/uv.lock index 5b72e97fce..df63607f40 100644 --- a/uv.lock +++ b/uv.lock @@ -18,10 +18,6 @@ members = [ "mcp-simple-resource", "mcp-simple-streamablehttp", "mcp-simple-streamablehttp-stateless", - "mcp-simple-task", - "mcp-simple-task-client", - "mcp-simple-task-interactive", - "mcp-simple-task-interactive-client", "mcp-simple-tool", "mcp-snippets", "mcp-sse-polling-client", @@ -1268,126 +1264,6 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] -[[package]] -name = "mcp-simple-task" -version = "0.1.0" -source = { editable = "examples/servers/simple-task" } -dependencies = [ - { name = "anyio" }, - { name = "click" }, - { name = "mcp" }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-client" -version = "0.1.0" -source = { editable = "examples/clients/simple-task-client" } -dependencies = [ - { name = "click" }, - { name = "mcp" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-interactive" -version = "0.1.0" -source = { editable = "examples/servers/simple-task-interactive" } -dependencies = [ - { name = "anyio" }, - { name = "click" }, - { name = "mcp" }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-interactive-client" -version = "0.1.0" -source = { editable = "examples/clients/simple-task-interactive-client" } -dependencies = [ - { name = "click" }, - { name = "mcp" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - [[package]] name = "mcp-simple-tool" version = "0.1.0"