diff --git a/docs/mkdocs/zh/tool.md b/docs/mkdocs/zh/tool.md
index 2e6b4e4..ca295af 100644
--- a/docs/mkdocs/zh/tool.md
+++ b/docs/mkdocs/zh/tool.md
@@ -33,6 +33,8 @@ Agent 通过以下步骤动态使用工具:
 | [Streaming Tools(流式工具)](#streaming-tools流式工具) | 实时预览长文本生成 | 使用 StreamingFunctionTool | 代码生成、文档写作 |
 | [WebFetchTool](#webfetchtool) | 抓取并文本化单个公网 URL | 实例化 WebFetchTool 并加入 tools | 阅读文档页、RFC、changelog、新闻 |
 | [WebSearchTool](#websearchtool) | 公网搜索引擎检索 | 实例化 WebSearchTool 并加入 tools | 实时资讯、版本发布、事实/定义查询 |
+| [TodoWriteTool](#todowritetool-任务清单工具) | 多步任务规划与进度跟踪(整表替换) | 挂载 `TodoWriteTool` | 短清单、无依赖编排、token 不敏感 |
+| [Task 工具族](#task-工具族结构化任务看板) | 结构化任务看板(按 id 增量更新 + 依赖) | 挂载 `TaskToolSet` | 长任务板、跨轮跟踪、blockedBy 依赖 |
 | [Agent Code Executor](./code_executor.md) | 自动生成并执行代码场景、数据处理场景 | 配置 CodeExecutor | API 自动调用、表格数据处理 |
 
 ---
@@ -2845,3 +2847,196 @@ DuckDuckGo provider 在命中 instant answer 时,`summary` 字段会包含 DDG
 - **DuckDuckGo 原始命中**(`ddg_raw_agent`,`dedup_urls=False`):保留 provider 原始召回列表,便于下游处理
 - **Google 基线**(`google_agent`,`safe=active`):真实公网搜索 + 服务端单域 `siteSearch` + 客户端多域过滤 + 黑名单 + per-call `lang` 覆盖
 - **Google 时效性 Agent**(`google_raw_agent`,`dateRestrict=m6` + `dedup_urls=False`):只保留过去 6 个月索引的结果,适合"最新/what's new"类查询
+
+---
+
+## TodoWriteTool(任务清单工具)
+
+`TodoWriteTool` 是框架内置的**结构化任务清单工具**,对齐 Claude Code / DeepAgents 的 `TodoWrite` 语义:模型通过单次 `todo_write` 调用发送**完整、更新后的清单**,工具校验后整体替换上一份清单,并将会话级 state 持久化,从而在多轮 `Runner.run_async` 之间保持计划与进度。
+
+适合**步骤较少、无显式依赖边、希望实现简单**的场景。若需要服务端分配 id、按 `taskId` 增量 patch、或 `blockedBy` / `blocks` 依赖编排,请改用下文 [Task 工具族](#task-工具族结构化任务看板)。
+
+### 功能特性
+
+- **整表替换**:每次调用传入完整 `todos` 数组,新列表**完全覆盖**旧列表(不做智能 merge);唯一合法的清空方式是显式传入 `todos: []`
+- **会话级持久化**:清单序列化为 JSON 写入 `tool_context.state["todos[:<branch>]"]`(默认前缀 `todos`,**勿用** `temp:`——该前缀会被 `BaseSessionService` 剥离且不持久化)
+- **子 Agent 隔离**:state key 追加 `:<branch>`,父 / 子 Agent 各自维护独立清单
+- **硬契约校验(代码强制)**:`content` / `activeForm` 非空、至多一个 `in_progress`、`content` 全局唯一;违反时返回 `INVALID_ARGS` / `INVALID_TODOS`
+- **Prompt
引导分层**:`DEFAULT_TODO_PROMPT` 经 `process_request` 自动注入 system instruction,描述使用时机与写法;与硬契约分离 +- **响应带回 diff**:成功时返回 `{message, todos, oldTodos}`,便于前端 / CLI 直接渲染当前清单与变更 +- **可选策略钩子**:`nudge_hooks` 只读回调,可在成功响应 `message` 末尾追加策略提示(不得修改清单) +- **全部完成后自动清空**:`clear_on_all_done=True`(默认)时,若传入列表全部为 `completed`,持久化为空列表,避免历史项堆积 + +### TodoWriteTool 参数 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `state_key_prefix` | `str` | `"todos"` | state key 前缀;勿使用 `temp:` | +| `clear_on_all_done` | `bool` | `True` | 全部为 `completed` 时是否清空持久化列表 | +| `default_nudge` | `str` | 内置文案 | 每次成功响应的基础提示语 | +| `nudge_hooks` | `Optional[List[NudgeHook]]` | `None` | 只读策略钩子列表 | +| `filters_name` / `filters` | — | `None` | 透传给 `BaseTool` 的 Filter | + +**LLM 调用参数**(`todo_write`): + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `todos` | `array` | 是 | 完整清单;每项含 `content`(祈使句)、`activeForm`(进行时)、`status`(`pending` / `in_progress` / `completed`) | + +**成功响应字段**: + +| 字段 | 类型 | 说明 | +|------|------|------| +| `message` | `str` | 基础 nudge + 钩子追加文案 | +| `todos` | `array` | 持久化后的当前清单 | +| `oldTodos` | `array \| null` | 更新前的清单(首次写入为 `null`) | + +### 使用方式 + +```python +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.tools import TodoWriteTool + +agent = LlmAgent( + name="todo_planner", + model=OpenAIModel(model_name="...", api_key="...", base_url="..."), + instruction="你是规划型助手,多步任务请用 todo_write 维护清单。", + tools=[TodoWriteTool()], +) +``` + +服务端 / 审计读取当前清单: + +```python +from trpc_agent_sdk.tools import get_todos, render_todos + +todos = get_todos(session, branch=agent.name) +print(render_todos(todos)) # ✅ / 🔄 / ⬜ 纯文本 checklist +``` + +### TodoWriteTool 与 Task 工具族对比 + +| 维度 | `TodoWriteTool` | `TaskToolSet` | +| --- | --- | --- | +| 工具数量 | 1(`todo_write`) | 4(`task_create` / `task_update` / `task_get` / `task_list`) | +| 更新方式 | 整表替换 | 按 `taskId` 增量 patch | +| 单项标识 | `content`(唯一键) | `id`(服务端分配) | +| 依赖编排 | 无 | `blockedBy` / 
`blocks`,完成上游自动 unblock |
+| state key | `todos[:branch]` | `tasks[:branch]` |
+| 并行 tool 调用 | 整表覆盖,天然 last-write-wins | 内置 `task_store_lock` 串行化 RMW |
+
+> **建议二选一挂载**;同时挂载易让模型混用两套语义。
+
+### TodoWriteTool 完整示例
+
+见 [examples/todo_tool/run_agent.py](../../../examples/todo_tool/run_agent.py):同一 session 内多轮「规划 → 逐项完成」,每轮用 `get_todos` 读回持久化清单。
+
+---
+
+## Task 工具族(结构化任务看板)
+
+`TaskToolSet` 暴露四个工具——`task_create`、`task_update`、`task_get`、`task_list`——对齐 Claude Code v2.1.142+ 的结构化 Task 能力。与 `TodoWriteTool` 的整表替换不同,Task 工具族采用**按服务端分配的 `id` 增量更新**:创建时返回 id,后续用 `task_update` 局部修改状态、字段或依赖边。
+
+整个看板序列化为**单个 JSON blob** 写入 `tool_context.state["tasks[:<branch>]"]`,跨轮存活;`highwatermark` 记录曾分配的最高 id,软删除(`status: deleted`)后**不会复用 id**。
+
+### 功能特性
+
+- **增量更新**:`task_create` 分配 id;`task_update` 按 `taskId` patch,无需重传整板
+- **依赖编排**:`addBlockedBy` / `removeBlockedBy`(及 `addBlocks` / `removeBlocks`)维护双向边;上游 `completed` 时自动从下游 `blockedBy` 移除并返回 `unblocked`
+- **Token 优化**:`task_list` 只返回摘要(省略 `description`);完整详情用 `task_get`
+- **硬契约校验**:`subject` 非空、状态合法、依赖存在、**无环**(`detect_cycle`)、默认**至多一个 `in_progress`**(`enforce_single_in_progress`,可关)
+- **并发安全**:`_TaskToolBase` 在 load → mutate → save 外包 `task_store_lock`(按 session + branch),兼容 `parallel_tool_calls=True` 下同批并行调用
+- **Prompt 自动注入**:`DEFAULT_TASK_PROMPT` 多工具挂载时只注入一次
+
+### TaskToolSet 构造参数
+
+| 参数 | 类型 | 默认值 | 说明 |
+|------|------|--------|------|
+| `state_key_prefix` | `str` | `"tasks"` | state key 前缀;勿使用 `temp:` |
+| `enforce_single_in_progress` | `bool` | `True` | 设置某任务 `in_progress` 时,若已有其他 `in_progress` 则拒绝 |
+| `inject_prompt` | `bool` | `True` | 是否向 system instruction 注入 `DEFAULT_TASK_PROMPT` |
+
+### 四个工具的 LLM 参数概要
+
+**`task_create`**
+
+| 参数 | 必填 | 说明 |
+|------|------|------|
+| `subject` | 是 | 短标题(祈使句) |
+| `description` | 否 | 自由文本详情 |
+| `activeForm` | 否 | 进行时文案 |
+| `metadata` | 否 | 扩展键值 |
+
+返回 `{task: {id, subject}, message}`。
+
+**`task_update`**
+
+| 参数 | 必填 | 说明 |
+|------|------|------|
+| `taskId` | 是 | 要更新的任务 id |
+|
`status` | 否 | `pending` / `in_progress` / `completed` / `deleted` | +| `subject` / `description` / `activeForm` / `owner` / `metadata` | 否 | 标量字段 patch | +| `addBlockedBy` / `removeBlockedBy` | 否 | 上游依赖 id 列表 | +| `addBlocks` / `removeBlocks` | 否 | 下游阻塞 id 列表 | + +返回 `{task, unblocked, message}`;`unblocked` 为因本次完成而解除阻塞的 pending 任务 id 列表。 + +**`task_get`**:`taskId`(必填)→ 含 `description` 的完整记录。 + +**`task_list`**:可选 `includeDeleted`;返回 `{tasks, stats}`,摘要不含 `description`。 + +**常见错误码**:`INVALID_ARGS`、`INVALID_DEPENDENCY`、`INVALID_STATUS`、`NOT_FOUND`。 + +### 使用方式 + +```python +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.tools import TaskToolSet + +agent = LlmAgent( + name="task_planner", + model=OpenAIModel(model_name="...", api_key="...", base_url="..."), + instruction="多步项目请用 task_create / task_update 维护看板。", + tools=[TaskToolSet()], + # parallel_tool_calls=True 时,同批多个 task 工具由 task_store_lock 保护 store 一致性 +) +``` + +读回持久化看板(REST / 审计 / demo 收尾): + +```python +from trpc_agent_sdk.tools import get_task_store, render_task_list + +store = get_task_store(session, branch=agent.name) +print(render_task_list(store)) +# ✅ #1 已完成 +# 🔄 #2 进行中 +# ⬜ #3 待办 (blocked by: 2) +``` + +### 依赖与解锁示例 + +```text +#1 设计表结构 + ├──→ #2 实现 API ──→ #3 单元测试 + └──→ #4 编写文档 + +#1 completed → unblocked: ['2', '4'] +#2 completed → unblocked: ['3'] +``` + +### Task 工具族最佳实践 + +- **规划与执行分离**:先 `task_create` 建板并 `addBlockedBy`,再逐项 `in_progress` → `completed` +- **不要编造 id**:只使用 `task_create` 返回的 id +- **并行调用**:开启 `parallel_tool_calls=True` 时,同 board 上的并发 `task_create` / `task_update` 由锁串行化;不同 `branch` 仍并行 +- **与 TodoWrite 二选一**:长板 + 依赖用 Task;短清单用 TodoWrite + +### Task 工具族完整示例 + +| 示例 | 说明 | +| --- | --- | +| [examples/task_tools](../../../examples/task_tools/) | 多轮对话:依赖编排、逐项完成、跨轮 `get_task_store` 读回看板 | +| [examples/task_tools_parallel](../../../examples/task_tools_parallel/) | 验证 `parallel_tool_calls` 与 `task_store_lock`(Phase 1–2 无需 API 
Key) |
diff --git a/examples/task_tools/.env b/examples/task_tools/.env
new file mode 100644
index 0000000..dc79139
--- /dev/null
+++ b/examples/task_tools/.env
@@ -0,0 +1,4 @@
+# Set TRPC_AGENT_API_KEY、TRPC_AGENT_BASE_URL、TRPC_AGENT_MODEL_NAME
+TRPC_AGENT_API_KEY=your-api-key
+TRPC_AGENT_BASE_URL=your-base-url
+TRPC_AGENT_MODEL_NAME=your-model-name
diff --git a/examples/task_tools/README.md b/examples/task_tools/README.md
new file mode 100644
index 0000000..a578a05
--- /dev/null
+++ b/examples/task_tools/README.md
@@ -0,0 +1,181 @@
+# Task 工具族(任务看板)示例
+
+本示例演示框架内置的 **Task 工具族**(`task_create` / `task_update` / `task_get` / `task_list`),对齐 Claude Code v2.1.142+ 的结构化 Task 能力。与 `TodoWriteTool` 的「整表替换」不同,Task 工具族采用**按 `taskId` 增量更新**模型:任务由服务端分配 `id`,支持 `blockedBy` / `blocks` 依赖编排。任务看板存放在**会话级 state**(key 前缀 `tasks`,**不用** `temp:`),因此可以**跨轮(跨 `Runner.run_async` 调用)存活**。
+
+## 关键特性
+
+- **增量更新语义**:`task_create` 创建任务并返回服务端分配的 `id`;`task_update` 按 `taskId` 局部 patch 状态 / 字段 / 依赖,不必重传整表。
+- **依赖编排**:`task_update` 的 `addBlockedBy` / `removeBlockedBy`(及 `addBlocks` / `removeBlocks`)维护双向依赖边;上游任务 `completed` 时自动从下游 `blockedBy` 中移除并报告 `unblocked`。
+- **Token 优化**:`task_list` 只返回摘要(`id` / `subject` / `status` / `owner` / `blockedBy`),**刻意省略 `description`**;需要完整详情用 `task_get`。
+- **会话级持久化**:整个看板序列化为单个 JSON blob 写入 `tool_context.state["tasks[:<branch>]"]`,随 function-response 事件的 state delta 自动落库,**跨轮存活**。注意:框架会剥离 `temp:` 前缀的 state,因此默认使用无前缀的 `tasks`。
+- **子 Agent 隔离**:state key 追加 `:<branch>`,不同 branch(父 / 子 Agent)各自维护独立看板;branch 为空时回退到 agent 名。
+- **硬契约校验(代码强制)**:`subject` 非空、状态合法、依赖引用存在、**依赖无环**(`detect_cycle`)、默认**至多一个 `in_progress`**(`enforce_single_in_progress`,可关);违反返回 `INVALID_ARGS` / `INVALID_DEPENDENCY` / `INVALID_STATUS` / `NOT_FOUND`。
+- **ID 不重用**:`highwatermark` 记录曾分配过的最高 id,软删除(`status: deleted`)后也不会复用。
+- **Prompt 引导分层**:使用时机与状态机建议放在 `DEFAULT_TASK_PROMPT`,由工具 `process_request` 自动注入(多工具挂载只注入一次),与硬契约清晰分层。
+
+## 与 `TodoWriteTool` 的关系
+
+| 维度 | `todo_write` | Task 工具族 |
+| --- | --- |
--- |
+| 工具数量 | 1 | 4 |
+| 更新方式 | 整表替换 | 按 `taskId` 增量 |
+| 单项标识 | `content`(唯一键) | `id`(服务端分配) |
+| 依赖 | 无 | `blockedBy` / `blocks` |
+| state key | `todos[:branch]` | `tasks[:branch]` |
+| 适用场景 | 短清单、token 不敏感 | 长任务板、多 Agent、依赖编排 |
+
+> 建议二选一挂载;同时挂载易让模型混用。
+
+## Agent 层级结构说明
+
+本例只有一个 `LlmAgent`,挂载 `TaskToolSet` 与文件/Shell 执行工具;`root_agent` 指向 `task_planner`:
+
+```text
+task_planner (LlmAgent)
+├── model: OpenAIModel
+├── instruction: 工程助手人设(DEFAULT_TASK_PROMPT 由工具 process_request 自动注入)
+├── tools:
+│   ├── TaskToolSet() → task_create / task_update / task_get / task_list
+│   ├── BashTool(cwd=work_dir)
+│   ├── WriteTool(cwd=work_dir)
+│   └── ReadTool(cwd=work_dir)
+└── session: InMemorySessionService(单一 session 跨轮复用)
+```
+
+关键文件:
+
+- [examples/task_tools/agent/agent.py](./agent/agent.py):构建 `task_planner`,挂载 `TaskToolSet`
+- [examples/task_tools/agent/prompts.py](./agent/prompts.py):规划型人设 instruction(`DEFAULT_TASK_PROMPT` 由 `process_request` 自动追加)
+- [examples/task_tools/agent/config.py](./agent/config.py):环境变量读取(LLM 凭据)
+- [examples/task_tools/run_agent.py](./run_agent.py):测试入口,在**同一个 session** 内驱动两轮静态站点搭建/优化;`task_update` 将状态设为 `in_progress` / `completed` 时实时渲染看板,每轮结束后用 `get_task_store` 读回持久化看板
+
+## 关键代码解释
+
+### 1) 挂载与配置(`agent/agent.py`)
+
+- `TaskToolSet()` 即可直接用,工具名 `task_create` / `task_update` / `task_get` / `task_list`(snake_case,满足部分 provider 的 `^[a-zA-Z0-9_-]+$` 命名约束)。
+- 构造参数:
+  - `state_key_prefix`(默认 `tasks`):状态 key 前缀;勿用 `temp:`,该前缀不会被 SessionService 持久化。
+  - `enforce_single_in_progress`(默认 `True`):设置某任务 `in_progress` 时若已有其他 `in_progress` 则拒绝。
+  - `inject_prompt`(默认 `True`):把 `DEFAULT_TASK_PROMPT` 注入 system instruction(多工具挂载只注入一次)。
+
+### 2) 跨轮持久化与 CLI 渲染(`run_agent.py`)
+
+- 所有轮次共用同一个 `session_id`,每轮一次 `runner.run_async`。
+- 工具把看板写进 `tasks:<branch>`,随事件 state delta 落库。
+- **实时看板**:每次 `task_update` 成功且状态变为 `in_progress` 或 `completed` 时,demo 从 session 读回 `get_task_store` 并打印 `📋 Current task board:`(仅改依赖边的 `task_update` 不触发)。
+- **轮末看板**:每轮结束后再次用
`get_task_store(session, branch=agent.name)` 读回持久化结果,打印 `📋 Persisted task board:`,证明跨轮存活。 +- 渲染符号与 `render_task_list` 一致: + - `✅` 已完成(completed) + - `🔄` 进行中(in_progress,显示 `activeForm`) + - `⬜` 待办(pending),并标注 `blocked by: ` + +### 3) 硬契约 vs Prompt 引导 + +- **硬契约(代码强制)**:`subject` 非空、状态合法、依赖存在且无环、至多一个 `in_progress`;违反返回明确错误码。 +- **Prompt 引导(鼓励不强制)**:`DEFAULT_TASK_PROMPT` 在挂载工具时经 `process_request` 自动追加到 system instruction。 +- 原则:要强制就加 validator,不要把约束塞进 prompt,两层保持可区分。 + +## 环境与运行 + +### 环境要求 + +- Python 3.12 + +### 安装步骤 + +```bash +git clone https://github.com/trpc-group/trpc-agent-python.git +cd trpc-agent-python +python3 -m venv .venv +source .venv/bin/activate +pip3 install -e . +``` + +### 环境变量要求 + +在 [examples/task_tools/.env](./.env) 中配置(或通过 `export`): + +- `TRPC_AGENT_API_KEY` +- `TRPC_AGENT_BASE_URL` +- `TRPC_AGENT_MODEL_NAME` + +### 运行命令 + +```bash +cd examples/task_tools +python3 run_agent.py +``` + +## Demo 状态转换流程 + +`run_agent.py` 在**同一个 session** 内驱动 2 轮对话:Turn 1 用 `task_create` 规划静态站点并逐项执行;Turn 2 追加 CSS/JS 相关任务并更新 README。典型工具链: + +```text +Turn 1 搭建静态站点 + task_create ×3 → task_update addBlockedBy → task_update in_progress + → Bash / Write 执行 → task_update completed(每步 in_progress / completed 后打印 Current task board) + +Turn 2 优化静态站点 + task_create ×4(id 从 #4 起)→ task_update 依赖与状态 → Bash / Write 执行 + → task_update completed(轮末 Persisted task board 含 #1–#7) +``` + +## 运行结果(示意) + +```text +🆔 Session ID: 71bb460a... (shared across all turns) +📂 Work dir: /path/to/examples/task_tools + +========== 搭建静态站点 ========== +📝 User: 请帮我在当前目录搭建 demo 静态站点 ... 
+🔧 [Invoke Tool: task_create({'subject': '创建 demo/ 及子目录 css/、js/', ...})] +📊 [Tool Result: created id=1 subject='创建 demo/ 及子目录 css/、js/'] +🔧 [Invoke Tool: task_update({'taskId': '2', 'addBlockedBy': ['1']})] +📊 [Tool Result: updated id=2 status=pending] +🔧 [Invoke Tool: task_update({'taskId': '1', 'status': 'in_progress'})] +📊 [Tool Result: updated id=1 status=in_progress] + +📋 Current task board: + 🔄 #1 创建 demo/ 及子目录 css/、js/ + ⬜ #2 创建 demo/index.html (blocked by: 1) + ⬜ #3 创建 demo/README.md (blocked by: 1) + +🔧 [Invoke Tool: task_update({'taskId': '1', 'status': 'completed'})] +📊 [Tool Result: updated id=1 status=completed unblocked=['2', '3']] + +📋 Current task board: + ✅ #1 创建 demo/ 及子目录 css/、js/ + ⬜ #2 创建 demo/index.html + ⬜ #3 创建 demo/README.md + +📋 Persisted task board: +✅ #1 创建 demo/ 及子目录 css/、js/ +✅ #2 创建 demo/index.html +✅ #3 创建 demo/README.md +---------------------------------------- + +========== 优化静态站点 ========== +🔧 [Invoke Tool: task_create({'subject': '创建 demo/css/style.css', ...})] +📊 [Tool Result: created id=4 subject='创建 demo/css/style.css'] +... (Turn 2 追加 #4–#7,id 延续 Turn 1) + +📋 Persisted task board: +✅ #1 创建 demo/ 及子目录 css/、js/ +✅ #2 创建 demo/index.html +✅ #3 创建 demo/README.md +✅ #4 创建 demo/css/style.css +✅ #5 创建 demo/js/app.js +✅ #6 更新 demo/index.html +✅ #7 更新 demo/README.md +---------------------------------------- +``` + +## 适用场景建议 + +- 长任务板、需要**显式依赖编排**或跨多轮跟踪:用 `TaskToolSet`。 +- 需要服务端 / REST / 审计读取当前看板:调用 `get_task_store(session, branch)`。 + +## 相关示例 + +- [task_tools_parallel](../task_tools_parallel/) — 验证 `parallel_tool_calls=True` 与 `task_store_lock` 下的并行 `task_create` / `task_update`(Phase 1–2 无需 API Key)。 diff --git a/examples/task_tools/agent/__init__.py b/examples/task_tools/agent/__init__.py new file mode 100644 index 0000000..bc6e483 --- /dev/null +++ b/examples/task_tools/agent/__init__.py @@ -0,0 +1,5 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. 
+# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. diff --git a/examples/task_tools/agent/agent.py b/examples/task_tools/agent/agent.py new file mode 100644 index 0000000..14becc3 --- /dev/null +++ b/examples/task_tools/agent/agent.py @@ -0,0 +1,54 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Agent module for the Task tools example.""" + +import os + +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.models import LLMModel +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.tools import TaskToolSet +from trpc_agent_sdk.tools.file_tools import BashTool +from trpc_agent_sdk.tools.file_tools import ReadTool +from trpc_agent_sdk.tools.file_tools import WriteTool + +from .config import get_model_config +from .prompts import INSTRUCTION + + +def _create_model() -> LLMModel: + """Create the LLM model used by the demo agent.""" + api_key, url, model_name = get_model_config() + return OpenAIModel(model_name=model_name, api_key=api_key, base_url=url) + + +def create_task_agent(work_dir: str | None = None) -> LlmAgent: + """Build an agent that plans, tracks, and executes multi-step work. + + Args: + work_dir: Working directory for ``Bash`` / ``Write`` / ``Read``. Defaults to ``os.getcwd()``. + + The toolset exposes ``task_create`` / ``task_update`` / ``task_get`` / + ``task_list``. The board is persisted to branch-scoped session state and + survives across ``Runner.run_async`` invocations. 
+    """
+    cwd = work_dir or os.getcwd()
+    return LlmAgent(
+        name="task_planner",
+        description=("Engineering assistant that plans and tracks multi-step projects step by step."),
+        model=_create_model(),
+        instruction=INSTRUCTION,
+        tools=[
+            TaskToolSet(),
+            BashTool(cwd=cwd),
+            WriteTool(cwd=cwd),
+            ReadTool(cwd=cwd),
+        ],
+    )
+
+
+task_agent = create_task_agent()
+root_agent = task_agent
diff --git a/examples/task_tools/agent/config.py b/examples/task_tools/agent/config.py
new file mode 100644
index 0000000..7f60e29
--- /dev/null
+++ b/examples/task_tools/agent/config.py
@@ -0,0 +1,19 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Agent config module."""
+
+import os
+
+
+def get_model_config() -> tuple[str, str, str]:
+    """Get LLM model config from environment variables."""
+    api_key = os.getenv('TRPC_AGENT_API_KEY', '')
+    url = os.getenv('TRPC_AGENT_BASE_URL', '')
+    model_name = os.getenv('TRPC_AGENT_MODEL_NAME', '')
+    if not api_key or not url or not model_name:
+        raise ValueError('TRPC_AGENT_API_KEY, TRPC_AGENT_BASE_URL, '
+                         'and TRPC_AGENT_MODEL_NAME must be set in environment variables')
+    return api_key, url, model_name
diff --git a/examples/task_tools/agent/prompts.py b/examples/task_tools/agent/prompts.py
new file mode 100644
index 0000000..720bea3
--- /dev/null
+++ b/examples/task_tools/agent/prompts.py
@@ -0,0 +1,10 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Prompts for the Task tools demo agent."""
+
+# Demo-specific persona only. ``DEFAULT_TASK_PROMPT`` is injected automatically
+# by the task tools' ``process_request`` when the toolset is registered.
+INSTRUCTION = """You are a rigorous engineering assistant that breaks a project down into tasks and works through them step by step."""
diff --git a/examples/task_tools/run_agent.py b/examples/task_tools/run_agent.py
new file mode 100644
index 0000000..54ae7e5
--- /dev/null
+++ b/examples/task_tools/run_agent.py
@@ -0,0 +1,225 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Demo entry point for the Task tools example.
+
+The script drives a multi-turn conversation in ONE session so it exercises
+the most important properties of the Task toolset: the board lives in
+session-level state and survives across ``Runner.run_async`` invocations,
+tasks carry server-assigned ids, and dependencies are enforced.
+
+After each ``task_update`` that sets ``status`` to ``in_progress`` or
+``completed``, the demo renders the current board (``✅`` / ``🔄`` / ``⬜``).
+At the end of each turn it reads the persisted board back from the session
+with :func:`get_task_store` to prove cross-turn persistence.
+
+Set TRPC_AGENT_API_KEY / TRPC_AGENT_BASE_URL / TRPC_AGENT_MODEL_NAME (see
+.env) before running.
+""" + +from __future__ import annotations + +import asyncio +import os +import shutil +import uuid + +from dotenv import load_dotenv +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.tools import get_task_store +from trpc_agent_sdk.tools import render_task_list +from trpc_agent_sdk.tools.task_tools._models import TaskStore +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + +load_dotenv() + +APP_NAME = "task_agent_demo" +USER_ID = "demo_user" + +TURNS = [ + ( + "搭建静态站点", + "请帮我在当前目录搭建 demo 静态站点。先规划任务并逐步执行:\n" + "1) 创建 demo/ 及子目录 css/、js/\n" + "2) demo/index.html:title 和 h1 均为「Task Demo」\n" + "3) demo/README.md,内容为「Task Demo」\n" + ), + ( + "优化静态站点", + "请优化静态站点,引入 css/style.css 与 js/app.js。用规划任务并逐步执行:\n" + "1) demo/css/style.css:body 居中、浅灰背景、无衬线字体\n" + "2) demo/js/app.js:DOMContentLoaded 时在 console 打印「Task demo loaded」\n" + "3) 更新 demo/index.html,引入 css/style.css 与 js/app.js\n" + "4) 更新 demo/README.md, 描述 index.html的内容\n" + ), +] + + +def _summarise_tool_response(name: str, resp) -> tuple[str, str | None]: + """Compact tool responses; also return ``message`` for nudge visibility.""" + if not isinstance(resp, dict): + return str(resp), None + message = resp.get("message") if isinstance(resp.get("message"), str) else None + if "error" in resp: + return f"error={resp['error']!r}", None + match name: + case "Bash": + stdout = (resp.get("stdout") or "").strip() + stdout = stdout[:80] + ("..." 
if len(stdout) > 80 else "") + return f"success={resp.get('success')} rc={resp.get('return_code')} stdout={stdout!r}", None + case "Write": + return f"path={resp.get('path')!r} success={resp.get('success')}", message + case "Read": + return f"path={resp.get('path')!r} lines={resp.get('total_lines')}", None + case "task_create": + task = resp.get("task") or {} + return f"created id={task.get('id')} subject={task.get('subject')!r}", message + case "task_update": + task = resp.get("task") or {} + unblocked = resp.get("unblocked") or [] + extra = f" unblocked={unblocked}" if unblocked else "" + return f"updated id={task.get('id')} status={task.get('status')}{extra}", message + case "task_list": + tasks = resp.get("tasks") or [] + return f"list items={len(tasks)} stats={resp.get('stats')}", None + case "task_get": + task = resp.get("task") or {} + return f"get id={task.get('id')} status={task.get('status')}", None + case _: + return str(resp), message + + +def _print_task_board(store: TaskStore, *, indent: str = " ") -> None: + """Render the persisted task board (same glyphs as render_task_list).""" + if not store.tasks: + print(f"{indent}(empty)") + return + print("\n📋 Current task board:") + for line in render_task_list(store).splitlines(): + print(f"{indent}{line}") + print("\n") + +def _should_print_task_board(name: str, resp: dict) -> bool: + """Print board after task_update sets in_progress or completed.""" + if name != "task_update" or "error" in resp: + return False + status = (resp.get("task") or {}).get("status") + return status in ("in_progress", "completed") + + +async def _print_event_parts( + event, + *, + final_text_parts: list[str], + runner: Runner, + session_id: str, + agent_name: str, +) -> None: + """Pretty-print non-partial assistant / tool events.""" + if not event.content or not event.content.parts: + return + if event.partial: + return + for part in event.content.parts: + if part.thought: + continue + if part.function_call: + print(f"🔧 [Invoke 
Tool: {part.function_call.name}({part.function_call.args})]") + elif part.function_response: + name = part.function_response.name + resp = part.function_response.response + summary, message = _summarise_tool_response(name, resp) + print(f"📊 [Tool Result: {summary}]") + if message: + print(f"💬 [Tool Message: {message}]") + if isinstance(resp, dict) and _should_print_task_board(name, resp): + session = await runner.session_service.get_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=session_id, + ) + store = get_task_store(session, branch=agent_name) + if not store.tasks: + store = get_task_store(session, branch="") + _print_task_board(store) + elif part.text: + final_text_parts.append(part.text) + + +async def _run_turn(runner: Runner, agent: LlmAgent, *, session_id: str, label: str, query: str) -> None: + """Drive a single user turn through ``runner`` and pretty-print events.""" + print(f"\n========== {label} ==========") + print(f"📝 User: {query}") + + final_text_parts: list[str] = [] + user_content = Content(parts=[Part.from_text(text=query)]) + async for event in runner.run_async( + user_id=USER_ID, + session_id=session_id, + new_message=user_content, + ): + if not event.content or not event.content.parts: + continue + await _print_event_parts( + event, + final_text_parts=final_text_parts, + runner=runner, + session_id=session_id, + agent_name=agent.name, + ) + + if final_text_parts: + print(f"🤖 Assistant: {''.join(final_text_parts)}") + + session = await runner.session_service.get_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=session_id, + ) + store = get_task_store(session, branch=agent.name) + if not store.tasks: + store = get_task_store(session, branch="") + print("\n📋 Persisted task board:") + print(render_task_list(store) if store.tasks else " (empty)") + print("-" * 40) + + +async def main() -> None: + from agent.agent import create_task_agent + + work_dir = os.getcwd() + demo_dir = os.path.join(work_dir, "demo") + if 
os.path.isdir(demo_dir):
+        shutil.rmtree(demo_dir)
+        print(f"🧹 Cleaned previous {demo_dir}")
+
+    task_agent = create_task_agent(work_dir=work_dir)
+    runner = Runner(
+        app_name=APP_NAME,
+        agent=task_agent,
+        session_service=InMemorySessionService(),
+    )
+
+    session_id = str(uuid.uuid4())
+    await runner.session_service.create_session(
+        app_name=APP_NAME,
+        user_id=USER_ID,
+        session_id=session_id,
+        state={"user_name": USER_ID},
+    )
+    print(f"🆔 Session ID: {session_id[:8]}... (shared across all turns)")
+    print(f"📂 Work dir: {work_dir}")
+
+    for label, query in TURNS:
+        await _run_turn(runner, task_agent, session_id=session_id, label=label, query=query)
+
+    await runner.close()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/todo_tool/.env b/examples/todo_tool/.env
new file mode 100644
index 0000000..dc79139
--- /dev/null
+++ b/examples/todo_tool/.env
@@ -0,0 +1,4 @@
+# Set TRPC_AGENT_API_KEY、TRPC_AGENT_BASE_URL、TRPC_AGENT_MODEL_NAME
+TRPC_AGENT_API_KEY=your-api-key
+TRPC_AGENT_BASE_URL=your-base-url
+TRPC_AGENT_MODEL_NAME=your-model-name
diff --git a/examples/todo_tool/README.md b/examples/todo_tool/README.md
new file mode 100644
index 0000000..1e9b282
--- /dev/null
+++ b/examples/todo_tool/README.md
@@ -0,0 +1,150 @@
+# TodoWriteTool 任务清单示例
+
+本示例演示框架内置的 `TodoWriteTool`,让 LLM Agent 把多步骤任务外显成一份**结构化、可持久化的待办清单**:先规划、再逐项执行、每完成一步翻转状态。清单存放在**会话级 state**(key 前缀 `todos`,**不用** `temp:`),因此可以**跨轮(跨 `Runner.run_async` 调用)存活**,Agent 能从上一轮停下的地方继续。
+
+## 关键特性
+
+- **整表替换语义**:模型每次调用 `todo_write` 都发送**完整的新列表**,整体替换旧列表,不做智能 merge(与 Claude Code / DeepAgents 路线一致,简单鲁棒)。
+- **会话级持久化**:清单写入 `tool_context.state["todos[:<branch>]"]`,随 function-response 事件的 state delta 自动落库,**跨轮存活**,无需额外存储机制。注意:框架会剥离 `temp:` 前缀的 state,因此 TodoWrite 默认使用无前缀的 `todos`。
+- **子 Agent 隔离**:state key 追加 `:<branch>`,不同 branch(父 / 子 Agent)各自维护独立清单,互不覆盖;branch 为空时回退到 agent 名。
+- **硬契约校验(代码强制)**:`content` / `activeForm` 非空、**至多一个 `in_progress`**、`content` 全表唯一,违反则工具调用返回 `INVALID_TODOS` 错误。
+-
**防误删守卫**:缺失 `todos` 字段或 `todos: null` 一律报错;唯一合法的清空手势是显式空数组 `todos: []`,避免上游丢字段误清整张计划。 +- **结构化回显**:返回 `{message, todos, oldTodos}` —— `message` 是给模型的 nudge,`todos/oldTodos` 供前端 / CLI 直接渲染当前列表与 diff,无需再查 session。 +- **clear_on_all_done**:全部 `completed` 时默认清空列表,避免已完成项跨轮无限堆积(本 demo 显式设为 `False` 以便展示最终全完成态)。 +- **NudgeHook 策略回调**:在持久化后、返回前调用的只读回调,返回的非空字符串追加进 message,可用于「预算告警 / 验证提醒 / 死循环检测」等策略而不改工具本体。 +- **Prompt 引导分层**:风格建议(恰好一个 in_progress、完成立即标记、不要复述整张清单)放在 `DEFAULT_TODO_PROMPT`,与硬契约清晰分层。 + +## Agent 层级结构说明 + +本例只有一个 `LlmAgent`,挂载单个 `TodoWriteTool`;`root_agent` 指向 `todo_planner`: + +```text +todo_planner (LlmAgent) +├── model: OpenAIModel +├── instruction: 规划型人设(DEFAULT_TODO_PROMPT 由工具 process_request 自动注入) +├── tools: +│ └── TodoWriteTool(clear_on_all_done=False, nudge_hooks=[_all_done_nudge_hook]) +└── session: InMemorySessionService(单一 session 跨轮复用) +``` + +关键文件: + +- [examples/todo_tool/agent/agent.py](./agent/agent.py):构建 `todo_planner`,挂载 `TodoWriteTool` 并演示自定义只读 `NudgeHook` +- [examples/todo_tool/agent/prompts.py](./agent/prompts.py):规划型人设 instruction(`DEFAULT_TODO_PROMPT` 由 `TodoWriteTool.process_request` 自动追加) +- [examples/todo_tool/agent/config.py](./agent/config.py):环境变量读取(LLM 凭据) +- [examples/todo_tool/run_agent.py](./run_agent.py):测试入口,在**同一个 session** 内驱动「规划 → 逐项完成」多轮对话,并在每轮后用 `get_todos` 读回持久化清单渲染成 ASCII checklist + +## 关键代码解释 + +### 1) 挂载与配置(`agent/agent.py`) + +- `TodoWriteTool()` 即可直接用,工具名默认 `todo_write`(snake_case,满足部分 provider 的 `^[a-zA-Z0-9_-]+$` 命名约束)。 +- 构造参数: + - `clear_on_all_done`(默认 `True`):全部完成时清空列表;本 demo 设为 `False` 以便看到最终「全部 ✅」的清单。 + - `state_key_prefix`(默认 `todos`):状态 key 前缀;勿用 `temp:`,该前缀不会被 SessionService 持久化。 + - `default_nudge`:每次成功响应追加的基础提醒。 + - `nudge_hooks`:只读策略回调列表。 + +### 2) 只读 NudgeHook(`_all_done_nudge_hook`) + +```python +def _all_done_nudge_hook(old, new): + if len(new) < 3: + return None + if not all(item.status == TodoStatus.COMPLETED for item in new): + return None + return "Reminder: all tasks are marked completed. 
..."
+```
+
+- 签名为 `(old: list[TodoItem], new: list[TodoItem]) -> Optional[str]`,在持久化后、返回前被调用。
+- 返回的非空字符串会追加进工具响应的 `message`,让模型看到。
+- 约定 **只读**:不可修改清单;本例与 Go `examples/todo` 对齐——当清单 ≥3 项且全部 `completed` 时,提醒模型在收尾前简要总结结果。
+
+### 3) 跨轮持久化与读回(`run_agent.py`)
+
+- 所有轮次共用同一个 `session_id`,每轮一次 `runner.run_async`。
+- 工具把清单写进 `todos:<branch>`,随事件 state delta 落库。
+- 每轮结束后用 `get_todos(session, branch=agent.name)` 把**持久化后的清单**读回来,证明它跨轮存活,再用 `render_todos` 渲染:
+  - `✅` 已完成(completed)
+  - `🔄` 进行中(in_progress,显示 `activeForm`)
+  - `⬜` 待办(pending)
+
+### 4) 硬契约 vs Prompt 引导
+
+- **硬契约(代码强制)**:`validate_todos` —— 字段非空、至多一个 `in_progress`、`content` 唯一;违反返回 `INVALID_TODOS`。
+- **防误删守卫**:缺字段 / `null` 报错,仅 `todos: []` 可清空。
+- **Prompt 引导(鼓励不强制)**:`DEFAULT_TODO_PROMPT` 在挂载工具时经 `process_request` 自动追加到 system instruction。
+- 原则:要强制就加 validator,不要把约束塞进 prompt,两层保持可区分。
+
+## 环境与运行
+
+### 环境要求
+
+- Python 3.12
+
+### 安装步骤
+
+```bash
+git clone https://github.com/trpc-group/trpc-agent-python.git
+cd trpc-agent-python
+python3 -m venv .venv
+source .venv/bin/activate
+pip3 install -e .
+```
+
+### 环境变量要求
+
+在 [examples/todo_tool/.env](./.env) 中配置(或通过 `export`):
+
+- `TRPC_AGENT_API_KEY`
+- `TRPC_AGENT_BASE_URL`
+- `TRPC_AGENT_MODEL_NAME`
+
+### 运行命令
+
+```bash
+cd examples/todo_tool
+python3 run_agent.py
+```
+
+## 运行结果(示意)
+
+```text
+🆔 Session ID: 1a2b3c4d...
(shared across all turns) + +========== 规划任务 ========== +📝 User: 请规划一个三步任务并用 todo_write 记录:1) 初始化项目骨架 2) 实现核心业务逻辑 3) 编写并跑通单元测试。先整体规划,把第一步设为进行中,其余为待办。 +🤖 Assistant: +🔧 [Invoke Tool: todo_write({'todos': [{'content': '初始化项目骨架', 'activeForm': '正在初始化项目骨架', 'status': 'in_progress'}, {'content': '实现核心业务逻辑', 'activeForm': '正在实现核心业务逻辑', 'status': 'pending'}, {'content': '编写并跑通单元测试', 'activeForm': '正在编写单元测试', 'status': 'pending'}]})] +📊 [Tool Result: items=3 old_items=0 [in_progress:初始化项目骨架, pending:实现核心业务逻辑, pending:编写并跑通单元测试]] +我已经把任务拆成三步,现在开始第一步:初始化项目骨架。 + +📋 Persisted checklist: +🔄 正在初始化项目骨架 +⬜ 实现核心业务逻辑 +⬜ 编写并跑通单元测试 +---------------------------------------- + +========== 完成第 1 步 ========== +📝 User: 第一步『初始化项目骨架』已经完成了,请更新清单并开始下一步。 +🤖 Assistant: +🔧 [Invoke Tool: todo_write({'todos': [{'content': '初始化项目骨架', 'activeForm': '正在初始化项目骨架', 'status': 'completed'}, {'content': '实现核心业务逻辑', 'activeForm': '正在实现核心业务逻辑', 'status': 'in_progress'}, {'content': '编写并跑通单元测试', 'activeForm': '正在编写单元测试', 'status': 'pending'}]})] +📊 [Tool Result: items=3 old_items=3 [completed:初始化项目骨架, in_progress:实现核心业务逻辑, pending:编写并跑通单元测试]] +第一步已完成,现在进行第二步:实现核心业务逻辑。 + +📋 Persisted checklist: +✅ 初始化项目骨架 +🔄 正在实现核心业务逻辑 +⬜ 编写并跑通单元测试 +---------------------------------------- + +... 
(第 2、3 步依次翻转,最终全部 ✅ 完成) +``` + +## 适用场景建议 + +- 复杂多步任务(代码生成、多文件改造、调研、部署)需要**规划外显 + 进度可视 + 可控收尾**:直接复用本示例。 +- 需要把清单接入前端 / AG-UI 实时渲染:消费工具响应里的 `todos` / `oldTodos`(已是纯 JSON 结构)。 +- 需要服务端 / REST / 审计读取当前清单:调用 `get_todos(session, branch)`。 +- 需要在工具调用前后插入日志、审计、参数校验:把 `TodoWriteTool(filters_name=[...])` 与 `before_tool_callback` / `after_tool_callback` 组合使用。 +- 需要「未完成项不允许收尾」的强约束:用 `LlmAgent` 的 `after_model_callback` / `before_model_callback` 实现 enforcer(参考实现方案文档的 Phase 2)。 diff --git a/examples/todo_tool/agent/__init__.py b/examples/todo_tool/agent/__init__.py new file mode 100644 index 0000000..bc6e483 --- /dev/null +++ b/examples/todo_tool/agent/__init__.py @@ -0,0 +1,5 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. diff --git a/examples/todo_tool/agent/agent.py b/examples/todo_tool/agent/agent.py new file mode 100644 index 0000000..ef718cb --- /dev/null +++ b/examples/todo_tool/agent/agent.py @@ -0,0 +1,75 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Agent module for the TodoWriteTool example.""" + +import os +from typing import List +from typing import Optional + +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.models import LLMModel +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.tools import TodoItem +from trpc_agent_sdk.tools import TodoStatus +from trpc_agent_sdk.tools import TodoWriteTool +from trpc_agent_sdk.tools.file_tools import BashTool +from trpc_agent_sdk.tools.file_tools import ReadTool +from trpc_agent_sdk.tools.file_tools import WriteTool + +from .config import get_model_config +from .prompts import INSTRUCTION + + +def _create_model() -> LLMModel: + """Create the LLM model used by the demo agent.""" + api_key, url, model_name = get_model_config() + return OpenAIModel(model_name=model_name, api_key=api_key, base_url=url) + + +def _all_done_nudge_hook(old: List[TodoItem], new: List[TodoItem]) -> Optional[str]: + """Example read-only NudgeHook (aligned with Go ``examples/todo``). + + When the plan has at least three items and every item is ``completed``, + append a reminder so the model summarises the outcome before wrapping up. + """ + if len(new) < 3: + return None + if not all(item.status == TodoStatus.COMPLETED for item in new): + return None + return ("Reminder: all tasks are marked completed. " + "Before finishing, briefly summarise the outcome for the user.") + + +def create_todo_agent(work_dir: str | None = None) -> LlmAgent: + """Build an agent that plans, tracks, and executes multi-step work. + + Args: + work_dir: Working directory for ``Bash`` / ``Write`` / ``Read``. Defaults to ``os.getcwd()``. + + ``clear_on_all_done=False`` keeps completed items visible so the demo + can render the final all-done checklist; production agents may keep + the default (``True``) to avoid stale items piling up across turns. 
+ """ + cwd = work_dir or os.getcwd() + todo_tool = TodoWriteTool( + clear_on_all_done=False, + nudge_hooks=[_all_done_nudge_hook], + ) + bash_tool = BashTool(cwd=cwd) + write_tool = WriteTool(cwd=cwd) + read_tool = ReadTool(cwd=cwd) + + return LlmAgent( + name="todo_planner", + description="Engineering assistant that plans and tracks multi-step tasks.", + model=_create_model(), + instruction=INSTRUCTION, + tools=[todo_tool, bash_tool, write_tool, read_tool], + ) + + +todo_agent = create_todo_agent() +root_agent = todo_agent diff --git a/examples/todo_tool/agent/config.py b/examples/todo_tool/agent/config.py new file mode 100644 index 0000000..7f60e29 --- /dev/null +++ b/examples/todo_tool/agent/config.py @@ -0,0 +1,19 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +""" Agent config module""" + +import os + + +def get_model_config() -> tuple[str, str, str]: + """Get LLM model config from environment variables.""" + api_key = os.getenv('TRPC_AGENT_API_KEY', '') + url = os.getenv('TRPC_AGENT_BASE_URL', '') + model_name = os.getenv('TRPC_AGENT_MODEL_NAME', '') + if not api_key or not url or not model_name: + raise ValueError('''TRPC_AGENT_API_KEY, TRPC_AGENT_BASE_URL, + and TRPC_AGENT_MODEL_NAME must be set in environment variables''') + return api_key, url, model_name diff --git a/examples/todo_tool/agent/prompts.py b/examples/todo_tool/agent/prompts.py new file mode 100644 index 0000000..5c1ab3f --- /dev/null +++ b/examples/todo_tool/agent/prompts.py @@ -0,0 +1,11 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Prompts for the TodoWrite demo agent.""" + +# Demo-specific persona only. 
``DEFAULT_TODO_PROMPT`` is injected automatically +# by ``TodoWriteTool.process_request`` when the tool is registered on the agent. +INSTRUCTION = ("You are a rigorous engineering assistant that breaks a task down and works " + "through it step by step.\n") diff --git a/examples/todo_tool/run_agent.py b/examples/todo_tool/run_agent.py new file mode 100644 index 0000000..1f7d0b5 --- /dev/null +++ b/examples/todo_tool/run_agent.py @@ -0,0 +1,185 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Demo entry point for the TodoWriteTool example. + +The script drives a multi-turn conversation in ONE session so it exercises +the most important property of ``TodoWriteTool``: the checklist lives in +session-level state and survives across ``Runner.run_async`` invocations. + +After each ``todo_write`` the demo renders the current checklist (``✅`` / +``🔄`` / ``⬜``). At the end of each turn it reads the persisted list back +from the session with :func:`get_todos` to prove cross-turn persistence. + +Set TRPC_AGENT_API_KEY / TRPC_AGENT_BASE_URL / TRPC_AGENT_MODEL_NAME (see +.env) before running.
+""" + +from __future__ import annotations + +import asyncio +import os +import shutil +import uuid + +from dotenv import load_dotenv +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.tools import get_todos +from trpc_agent_sdk.tools import render_todos +from trpc_agent_sdk.tools import TodoItem +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + +load_dotenv() + +APP_NAME = "todo_agent_demo" +USER_ID = "demo_user" + +TURNS = [ + ( + "搭建静态站点", + "请帮我在当前目录搭建一个 demo 静态站点,要求:\n" + "1) 创建 demo/ 及子目录 css/、js/\n" + "2) demo/index.html:title 和 h1 均为「Todo Demo」\n" + "3) 生成 demo/README.md 文件,内容为「Todo Demo」\n" + ), + ( + "优化静态站点", + "请优化静态站点,引入 css/style.css 与 js/app.js,要求:\n" + "1) demo/css/style.css:body 居中、浅灰背景、无衬线字体\n" + "2) demo/js/app.js:DOMContentLoaded 时在 console 打印「Todo HITL demo loaded」\n" + "3) 更新 demo/README.md 文件,内容为「Todo Demo with css and js」\n" + ), +] + + +def _summarise_tool_response(name: str, resp) -> tuple[str, str | None]: + """Compact tool responses; also return ``message`` for nudge visibility.""" + if not isinstance(resp, dict): + return str(resp), None + message = resp.get("message") if isinstance(resp.get("message"), str) else None + if "error" in resp: + return f"error={resp['error']!r}", None + match name: + case "Bash": + stdout = (resp.get("stdout") or "").strip() + stdout = stdout[:80] + ("..." 
if len(stdout) > 80 else "") + return f"success={resp.get('success')} rc={resp.get('return_code')} stdout={stdout!r}", None + case "Write": + return f"path={resp.get('path')!r} success={resp.get('success')}", message + case "Read": + return f"path={resp.get('path')!r} lines={resp.get('total_lines')}", None + case "todo_write": + old = resp.get("oldTodos") or [] + return f"items={len(resp.get('todos') or [])} old_items={len(old)}", message + case _: + return str(resp), message + + +def _print_todo_checklist(resp: dict, *, indent: str = " ") -> None: + """Render todo_write response as a checklist (same format as end-of-turn summary).""" + raw = resp.get("todos") or [] + if not raw: + print(f"{indent}(empty)") + return + try: + items = [TodoItem.model_validate(t) for t in raw] + except Exception: + return + print("📋 Current checklist:") + for line in render_todos(items).splitlines(): + print(f"{indent}{line}") + + +def _print_event_parts(event, *, final_text_parts: list[str]) -> None: + """Pretty-print non-partial assistant / tool events.""" + if not event.content or not event.content.parts: + return + if event.partial: + return + for part in event.content.parts: + if part.thought: + continue + if part.function_call: + print(f"🔧 [Invoke Tool: {part.function_call.name}({part.function_call.args})]") + elif part.function_response: + name = part.function_response.name + resp = part.function_response.response + summary, message = _summarise_tool_response(name, resp) + print(f"📊 [Tool Result: {summary}]") + if message: + print(f"💬 [Tool Message: {message}]") + if name == "todo_write" and isinstance(resp, dict) and "error" not in resp: + _print_todo_checklist(resp) + elif part.text: + final_text_parts.append(part.text) + + +async def _run_turn(runner: Runner, agent: LlmAgent, *, session_id: str, label: str, query: str) -> None: + """Drive a single user turn through ``runner`` and pretty-print events.""" + print(f"\n========== {label} ==========") + print(f"📝 User: {query}") 
+ + final_text_parts: list[str] = [] + user_content = Content(parts=[Part.from_text(text=query)]) + async for event in runner.run_async( + user_id=USER_ID, + session_id=session_id, + new_message=user_content, + ): + if not event.content or not event.content.parts: + continue + _print_event_parts(event, final_text_parts=final_text_parts) + + if final_text_parts: + print(f"🤖 Assistant: {''.join(final_text_parts)}") + + session = await runner.session_service.get_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=session_id, + ) + todos = get_todos(session, branch=agent.name) or get_todos(session, branch="") + print("\n📋 Persisted checklist:") + print(render_todos(todos) if todos else " (empty)") + print("-" * 40) + + +async def main() -> None: + from agent.agent import create_todo_agent + + work_dir = os.getcwd() + demo_dir = os.path.join(work_dir, "demo") + if os.path.isdir(demo_dir): + shutil.rmtree(demo_dir) + print(f"🧹 Cleaned previous {demo_dir}") + + todo_agent = create_todo_agent(work_dir=work_dir) + runner = Runner( + app_name=APP_NAME, + agent=todo_agent, + session_service=InMemorySessionService(), + ) + + session_id = str(uuid.uuid4()) + await runner.session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=session_id, + state={"user_name": USER_ID}, + ) + print(f"🆔 Session ID: {session_id[:8]}... 
(shared across all turns)") + print(f"📂 Work dir: {work_dir}") + + for label, query in TURNS: + await _run_turn(runner, todo_agent, session_id=session_id, label=label, query=query) + + await runner.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/todo_tool_with_human_in_the_loop/.env b/examples/todo_tool_with_human_in_the_loop/.env new file mode 100644 index 0000000..dc79139 --- /dev/null +++ b/examples/todo_tool_with_human_in_the_loop/.env @@ -0,0 +1,4 @@ +# Set TRPC_AGENT_API_KEY、TRPC_AGENT_BASE_URL、TRPC_AGENT_MODEL_NAME +TRPC_AGENT_API_KEY=your-api-key +TRPC_AGENT_BASE_URL=your-base-url +TRPC_AGENT_MODEL_NAME=your-model-name diff --git a/examples/todo_tool_with_human_in_the_loop/README.md b/examples/todo_tool_with_human_in_the_loop/README.md new file mode 100644 index 0000000..42b0472 --- /dev/null +++ b/examples/todo_tool_with_human_in_the_loop/README.md @@ -0,0 +1,193 @@ +# TodoWriteTool + Human-in-the-Loop 示例 + +本示例在 [TodoWriteTool 任务清单示例](../todo_tool/README.md) 之上,演示 **Human-in-the-Loop(人机协同)** 与 **真实文件执行** 的组合流程: + +1. Agent 提交多步计划 → `request_todo_plan_approval` 触发 `LongRunningEvent` 暂停; +2. 人工审阅 / 修改计划(demo 中模拟为**追加一条 todo**)→ 通过 `FunctionResponse` 恢复; +3. Agent 调用 `todo_write` 持久化清单,再用 `Bash` / `Write` / `Read` 逐步执行; +4. 
清单跨轮保存在 session state,每轮结束后 CLI 读回并渲染 ASCII checklist。 + +Demo 场景:在当前工作目录搭建 `demo/` 静态站点(目录结构 + HTML + CSS + JS),审批阶段人工追加「生成 README 文件」任务。 + +## 关键特性 + +- **计划审批门禁**:`request_todo_plan_approval` 经 `LongRunningFunctionTool` 包装,返回 `pending_approval` 后触发 `LongRunningEvent`,Agent 暂停等待人工介入。 +- **审批时可改计划**:人工可在 `FunctionResponse.response` 中修改 `todos` 列表(本 demo 自动追加「生成 README 文件」),Agent 恢复后按**更新后的清单**执行。 +- **审批后持久化 + 逐步执行**:恢复后先 `todo_write`,再 `Bash` / `Write` / `Read` 落地文件;工具成功后才标记 `completed`。 +- **文件执行能力**:挂载 `BashTool`、`WriteTool`、`ReadTool`,工作目录默认为 `run_agent.py` 的当前目录。 +- **继承 TodoWrite 能力**:整表替换、硬契约校验、NudgeHook、跨轮持久化等与 `todo_tool` 示例一致。 +- **启动清理**:每次运行前删除已有 `demo/` 目录,便于重复验证。 + +## Agent 层级结构 + +```text +todo_planner (LlmAgent) +├── model: OpenAIModel +├── instruction: 工程助手人设(DEFAULT_TODO_PROMPT 由 TodoWriteTool 自动注入) +├── tools: +│ ├── request_todo_plan_approval (LongRunningFunctionTool) ← 计划审批 +│ ├── TodoWriteTool(clear_on_all_done=False, nudge_hooks=[...]) ← 清单持久化 +│ ├── BashTool(cwd=work_dir) ← mkdir 等 shell 操作 +│ ├── WriteTool(cwd=work_dir) ← 写 HTML / CSS / JS / README +│ └── ReadTool(cwd=work_dir) ← 读回验证 +└── session: InMemorySessionService(单一 session) +``` + +关键文件: + +- [agent/tools.py](./agent/tools.py):`request_todo_plan_approval`(校验 + 返回 `preview`) +- [agent/agent.py](./agent/agent.py):组装 Agent 与全部工具 +- [agent/prompts.py](./agent/prompts.py):Agent 人设 instruction +- [agent/config.py](./agent/config.py):LLM 环境变量 +- [run_agent.py](./run_agent.py):驱动对话、捕获 HITL 事件、模拟人工改计划、恢复执行 + +## 关键代码解释 + +### 1) 审批工具(`agent/tools.py`) + +```python +async def request_todo_plan_approval(todos: list, summary: str = "") -> dict: + # 与 todo_write 相同的 validate_todos 硬契约 + return { + "status": "pending_approval", + "todos": [...], + "preview": render_todos(items), # ASCII checklist,供 CLI / 前端展示 + ... 
+ } +``` + +- 不写入 session state,仅提交待审计划。 +- `preview` 可直接展示给审批人。 + +### 2) 人工改计划并恢复(`run_agent.py` · `_build_approval_resume`) + +本 demo **不在终端等待真实输入**,而是在审批回调里模拟人工编辑计划: + +```python +todos = list(response_data.get("todos") or []) +todos.append({ + "content": "生成 README 文件", + "activeForm": "正在生成 README 文件", + "status": "pending", +}) +response_data["status"] = "approved" +response_data["todos"] = todos # 把修改后的完整列表回填给 Agent +``` + +Agent 恢复后会看到 `message` 里说明「已追加 README todo」,并应用更新后的 `todos` 调用 `todo_write`。 + +生产环境 / AG-UI 前端:捕获 `LongRunningEvent`,让用户在前端增删改 todo 后,构造同样的 `FunctionResponse` 提交即可。参见 [llmagent_with_human_in_the_loop](../llmagent_with_human_in_the_loop/README.md)。 + +### 3) HITL 事件捕获(`run_agent.py` · `_consume_run`) + +```python +async for event in runner.run_async(...): + if isinstance(event, LongRunningEvent): + # 展示 function_call.args 与 preview,Agent 暂停 + captured = event + +# 构造 FunctionResponse 恢复(使用捕获到的 LongRunningEvent,而非循环变量) +resume_content = Content(role="user", parts=[ + Part(function_response=FunctionResponse( + id=captured.function_response.id, + name=captured.function_response.name, + response=response_data, # status=approved + 修改后的 todos + )) +]) +async for event in runner.run_async(..., new_message=resume_content): ... # run_async 是异步生成器,需用 async for 消费 +``` + +### 4) 文件工具(`agent/agent.py`) + +```python +cwd = work_dir or os.getcwd() +bash_tool = BashTool(cwd=cwd) +write_tool = WriteTool(cwd=cwd) +read_tool = ReadTool(cwd=cwd) +``` + +工具名必须为 `Bash` / `Write` / `Read`(区分大小写),由框架 schema 暴露给模型。 + +## 环境与运行 + +### 环境要求 + +- Python 3.12 + +### 安装步骤 + +```bash +git clone https://github.com/trpc-group/trpc-agent-python.git +cd trpc-agent-python +python3 -m venv .venv +source .venv/bin/activate +pip3 install -e .
+``` + +### 环境变量 + +在 [examples/todo_tool_with_human_in_the_loop/.env](./.env) 中配置(或通过 `export`): + +- `TRPC_AGENT_API_KEY` +- `TRPC_AGENT_BASE_URL` +- `TRPC_AGENT_MODEL_NAME` + +### 运行命令 + +```bash +cd examples/todo_tool_with_human_in_the_loop +python3 run_agent.py +``` + +运行后会在**当前目录**生成 `demo/` 项目;再次运行会先清理旧目录。 + +可在 [run_agent.py](./run_agent.py) 的 `TURNS` 中扩展更多轮次(`(label, query)` 二元组)。 + +## 运行结果(示意) + +```text +🧹 Cleaned previous .../demo +🆔 Session ID: 57780a8a... (shared across all turns) +📂 Work dir: .../todo_tool_with_human_in_the_loop + +========== 搭建静态站点 ========== +📝 User: 请帮我在当前目录搭建一个 demo 静态站点... +🔧 [Invoke Tool: request_todo_plan_approval({...})] + +🔄 [Long-running operation detected — waiting for human approval] + Proposed checklist: + ⬜ 创建 demo/、demo/css/、demo/js/ 目录结构 + ⬜ 创建 demo/index.html ... + ⬜ 创建 demo/css/style.css ... + ⬜ 创建 demo/js/app.js ... + +👤 [Human approval with plan edit] + Decision: approved by demo_user + Edit: added todo → 生成 README 文件 + Updated checklist: + ⬜ 创建 demo/ ... + ... + ⬜ 生成 README 文件 ← 人工追加 + +🔄 Resuming agent after human approval... +🔧 [Invoke Tool: todo_write({...})] +🔧 [Invoke Tool: Bash({'command': 'mkdir -p demo/css demo/js'})] +🔧 [Invoke Tool: Write({'path': 'demo/index.html', ...})] +... +📋 Persisted checklist: +✅ 创建 demo/ ... 
+✅ 生成 README 文件 +---------------------------------------- +``` + +## 适用场景 + +| 需求 | 建议 | +|------|------| +| 执行前需人工确认 / 修改计划 | 复用 `request_todo_plan_approval` + `FunctionResponse` 改 `todos` | +| 计划 + 文件操作 + 进度跟踪 | 本示例(TodoWrite + File Tools + HITL) | +| AG-UI / REST 接入审批 UI | 捕获 `LongRunningEvent`,前端回填 `FunctionResponse` | +| 仅需 TodoWrite,无审批 | [todo_tool](../todo_tool/README.md) | +| 仅需 HITL 机制演示 | [llmagent_with_human_in_the_loop](../llmagent_with_human_in_the_loop/README.md) | + +相关文档:[human_in_the_loop.md](../../docs/mkdocs/zh/human_in_the_loop.md) · [tool.md(File Tools)](../../docs/mkdocs/zh/tool.md) diff --git a/examples/todo_tool_with_human_in_the_loop/agent/__init__.py b/examples/todo_tool_with_human_in_the_loop/agent/__init__.py new file mode 100644 index 0000000..bc6e483 --- /dev/null +++ b/examples/todo_tool_with_human_in_the_loop/agent/__init__.py @@ -0,0 +1,5 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. diff --git a/examples/todo_tool_with_human_in_the_loop/agent/agent.py b/examples/todo_tool_with_human_in_the_loop/agent/agent.py new file mode 100644 index 0000000..4ba680b --- /dev/null +++ b/examples/todo_tool_with_human_in_the_loop/agent/agent.py @@ -0,0 +1,78 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Agent module for the TodoWriteTool + Human-in-the-Loop example.""" + +import os +from typing import List +from typing import Optional + +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.models import LLMModel +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.tools import LongRunningFunctionTool +from trpc_agent_sdk.tools import TodoItem +from trpc_agent_sdk.tools import TodoStatus +from trpc_agent_sdk.tools import TodoWriteTool +from trpc_agent_sdk.tools.file_tools import BashTool +from trpc_agent_sdk.tools.file_tools import ReadTool +from trpc_agent_sdk.tools.file_tools import WriteTool + +from .config import get_model_config +from .prompts import INSTRUCTION +from .tools import request_todo_plan_approval + + +def _create_model() -> LLMModel: + """Create the LLM model used by the demo agent.""" + api_key, url, model_name = get_model_config() + return OpenAIModel(model_name=model_name, api_key=api_key, base_url=url) + + +def _all_done_nudge_hook(old: List[TodoItem], new: List[TodoItem]) -> Optional[str]: + """Example read-only NudgeHook (aligned with Go ``examples/todo``). + + When the plan has at least three items and every item is ``completed``, + append a reminder so the model summarises the outcome before wrapping up. + """ + if len(new) < 3: + return None + if not all(item.status == TodoStatus.COMPLETED for item in new): + return None + return ("Reminder: all tasks are marked completed. " + "Before finishing, briefly summarise the outcome for the user.") + + +def create_todo_agent(work_dir: str | None = None) -> LlmAgent: + """Build an agent with HITL todo planning plus file/shell execution tools. + + Args: + work_dir: Working directory for ``Bash`` / ``Write`` / ``Read``. Defaults to ``os.getcwd()``. + + ``clear_on_all_done=False`` keeps completed items visible so the demo + can render the final all-done checklist; production agents may keep + the default (``True``) to avoid stale items piling up across turns. 
+ """ + cwd = work_dir or os.getcwd() + todo_tool = TodoWriteTool( + clear_on_all_done=False, + nudge_hooks=[_all_done_nudge_hook], + ) + plan_approval_tool = LongRunningFunctionTool(request_todo_plan_approval) + bash_tool = BashTool(cwd=cwd) + write_tool = WriteTool(cwd=cwd) + read_tool = ReadTool(cwd=cwd) + + return LlmAgent( + name="todo_planner", + description="Engineering assistant that plans and tracks multi-step tasks.", + model=_create_model(), + instruction=INSTRUCTION, + tools=[plan_approval_tool, todo_tool, bash_tool, write_tool, read_tool], + ) + + +todo_agent = create_todo_agent() +root_agent = todo_agent diff --git a/examples/todo_tool_with_human_in_the_loop/agent/config.py b/examples/todo_tool_with_human_in_the_loop/agent/config.py new file mode 100644 index 0000000..7f60e29 --- /dev/null +++ b/examples/todo_tool_with_human_in_the_loop/agent/config.py @@ -0,0 +1,19 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +""" Agent config module""" + +import os + + +def get_model_config() -> tuple[str, str, str]: + """Get LLM model config from environment variables.""" + api_key = os.getenv('TRPC_AGENT_API_KEY', '') + url = os.getenv('TRPC_AGENT_BASE_URL', '') + model_name = os.getenv('TRPC_AGENT_MODEL_NAME', '') + if not api_key or not url or not model_name: + raise ValueError('''TRPC_AGENT_API_KEY, TRPC_AGENT_BASE_URL, + and TRPC_AGENT_MODEL_NAME must be set in environment variables''') + return api_key, url, model_name diff --git a/examples/todo_tool_with_human_in_the_loop/agent/prompts.py b/examples/todo_tool_with_human_in_the_loop/agent/prompts.py new file mode 100644 index 0000000..74b54da --- /dev/null +++ b/examples/todo_tool_with_human_in_the_loop/agent/prompts.py @@ -0,0 +1,11 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. 
+# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Prompts for the TodoWrite + Human-in-the-Loop demo agent.""" + +# Demo-specific persona only. ``DEFAULT_TODO_PROMPT`` is injected automatically +# by ``TodoWriteTool.process_request`` when the tool is registered on the agent. +INSTRUCTION = ("You are a rigorous engineering assistant that breaks a task down and works " + "through it step by step.\n") \ No newline at end of file diff --git a/examples/todo_tool_with_human_in_the_loop/agent/tools.py b/examples/todo_tool_with_human_in_the_loop/agent/tools.py new file mode 100644 index 0000000..29642e5 --- /dev/null +++ b/examples/todo_tool_with_human_in_the_loop/agent/tools.py @@ -0,0 +1,54 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Human-in-the-loop tools layered on top of TodoWriteTool.""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from pydantic import ValidationError + +from trpc_agent_sdk.tools import TodoItem +from trpc_agent_sdk.tools import render_todos +from trpc_agent_sdk.tools import validate_todos + + +async def request_todo_plan_approval(todos: list[dict[str, Any]], summary: str = "") -> dict: + """Request human approval before persisting a new todo plan. + + The agent should call this tool **before** the first ``todo_write`` when + laying out a multi-step plan. After the human approves (or edits) the + plan, call ``todo_write`` with the approved list. Subsequent status + updates can go directly to ``todo_write`` without another approval. + + Args: + todos: Proposed complete todo list (same shape as ``todo_write``). + summary: Short rationale for the plan shown to the reviewer. + + Returns: + A pending-approval payload consumed by ``LongRunningEvent`` handling.
+ """ + if not isinstance(todos, list): + return {"error": "INVALID_ARGS: `todos` must be an array"} + + try: + items = [TodoItem.model_validate(x) for x in todos] + except (ValidationError, TypeError) as exc: + return {"error": f"INVALID_ARGS: each todo must have content/activeForm/status: {exc}"} + + if err := validate_todos(items): + return {"error": f"INVALID_TODOS: {err}"} + + return { + "status": "pending_approval", + "message": summary or "New todo plan requires human approval before persisting.", + "todos": [t.model_dump(mode="json", by_alias=True) for t in items], + "preview": render_todos(items), + "approval_id": str(uuid.uuid4()), + "timestamp": time.time(), + } diff --git a/examples/todo_tool_with_human_in_the_loop/run_agent.py b/examples/todo_tool_with_human_in_the_loop/run_agent.py new file mode 100644 index 0000000..5d0026c --- /dev/null +++ b/examples/todo_tool_with_human_in_the_loop/run_agent.py @@ -0,0 +1,280 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Demo entry point for TodoWriteTool with Human-in-the-Loop approval. + +The script drives a multi-turn conversation in ONE session so it exercises: + +- ``request_todo_plan_approval`` (``LongRunningFunctionTool``) pauses the + agent until a human approves the initial plan. +- ``todo_write`` persists the checklist in session-level state across turns. +- ``Bash`` / ``Write`` / ``Read`` execute and verify file operations. + +Flow (one turn by default; append ``(label, query)`` tuples to ``TURNS`` for more): + +1. The agent drafts a plan for a static site under ``demo/`` and submits it + through ``request_todo_plan_approval``, which pauses the run. +2. The demo simulates a human edit (one extra README todo) and resumes. +3. The agent persists the approved list with ``todo_write``. +4. The agent executes each step with ``Bash`` / ``Write`` / ``Read``.
+ +After each turn the demo reads the PERSISTED list back from the session +with :func:`get_todos` and renders it as an ASCII checklist. + +Set TRPC_AGENT_API_KEY / TRPC_AGENT_BASE_URL / TRPC_AGENT_MODEL_NAME (see +.env) before running. +""" + +from __future__ import annotations + +import asyncio +import os +import shutil +import time +import uuid +from typing import Optional + +from dotenv import load_dotenv +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.events import LongRunningEvent +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.tools import get_todos +from trpc_agent_sdk.tools import render_todos +from trpc_agent_sdk.tools import TodoItem +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import FunctionResponse +from trpc_agent_sdk.types import Part + +load_dotenv() + +APP_NAME = "todo_hitl_demo" +USER_ID = "demo_user" + +TURNS = [ + ( + "搭建静态站点", + "请帮我在当前目录搭建一个 demo 静态站点,要求:\n" + "1) 创建 demo/ 及子目录 css/、js/\n" + "2) demo/index.html:title 和 h1 均为「Todo HITL Demo」,引入 css/style.css 与 js/app.js\n" + "3) demo/css/style.css:body 居中、浅灰背景、无衬线字体\n" + "4) demo/js/app.js:DOMContentLoaded 时在 console 打印「Todo HITL demo loaded」\n" + "请先提交完整计划等我审批;审批通过后逐步执行", + ) +] + + +def _summarise_tool_response(name: str, resp) -> tuple[str, str | None]: + """Compact tool responses; also return ``message`` for nudge visibility.""" + if not isinstance(resp, dict): + return str(resp), None + message = resp.get("message") if isinstance(resp.get("message"), str) else None + if "error" in resp: + return f"error={resp['error']!r}", None + match name: + case "Bash": + stdout = (resp.get("stdout") or "").strip() + stdout = stdout[:80] + ("..." 
if len(stdout) > 80 else "") + return f"success={resp.get('success')} rc={resp.get('return_code')} stdout={stdout!r}", None + case "Write": + return f"path={resp.get('path')!r} success={resp.get('success')}", message + case "Read": + return f"path={resp.get('path')!r} lines={resp.get('total_lines')}", None + case "todo_write": + old = resp.get("oldTodos") or [] + return f"items={len(resp.get('todos') or [])} old_items={len(old)}", message + case "request_todo_plan_approval": + todos = resp.get("todos") or [] + preview = resp.get("preview") or "" + return f"pending_approval items={len(todos)} preview={preview!r}", message + case _: + return str(resp), message + + +def _print_todo_checklist(resp: dict, *, indent: str = " ") -> None: + """Render todo_write response as a checklist (same format as end-of-turn summary).""" + raw = resp.get("todos") or [] + if not raw: + print(f"{indent}(empty)") + return + try: + items = [TodoItem.model_validate(t) for t in raw] + except Exception: + return + print("📋 Current checklist:") + for line in render_todos(items).splitlines(): + print(f"{indent}{line}") + + +def _print_event_parts(event, *, final_text_parts: list[str]) -> None: + """Pretty-print non-partial assistant / tool events.""" + if not event.content or not event.content.parts: + return + if event.partial: + return + for part in event.content.parts: + if part.thought: + continue + if part.function_call: + print(f"🔧 [Invoke Tool: {part.function_call.name}({part.function_call.args})]") + elif part.function_response: + name = part.function_response.name + resp = part.function_response.response + summary, message = _summarise_tool_response(name, resp) + print(f"📊 [Tool Result: {summary}]") + if message: + print(f"💬 [Tool Message: {message}]") + if name == "todo_write" and isinstance(resp, dict) and "error" not in resp: + _print_todo_checklist(resp) + elif part.text: + final_text_parts.append(part.text) + + +async def _consume_run( + runner: Runner, + *, + session_id: str, + 
content: Content, +) -> tuple[list[str], Optional[LongRunningEvent]]: + """Run one ``runner.run_async`` invocation; capture text and HITL events.""" + final_text_parts: list[str] = [] + long_running_event: Optional[LongRunningEvent] = None + + async for event in runner.run_async( + user_id=USER_ID, + session_id=session_id, + new_message=content, + ): + if isinstance(event, LongRunningEvent): + long_running_event = event + resp = event.function_response.response + print("\n🔄 [Long-running operation detected — waiting for human approval]") + print(f" Function: {event.function_call.name}") + print(f" Args: {event.function_call.args}") + if isinstance(resp, dict) and resp.get("preview"): + print(" Proposed checklist:") + for line in str(resp["preview"]).splitlines(): + print(f" {line}") + continue + _print_event_parts(event, final_text_parts=final_text_parts) + + return final_text_parts, long_running_event + + +def _build_approval_resume(long_running_event: LongRunningEvent) -> Content: + """Simulate human approval with a plan edit: inject an extra todo.""" + response_data = dict(long_running_event.function_response.response) + if response_data.get("status") != "pending_approval": + raise ValueError(f"Expected pending_approval, got {response_data.get('status')!r}") + + todos = list(response_data.get("todos") or []) + todos.append({ + "content": "生成 README 文件", + "activeForm": "正在生成 README 文件", + "status": "pending", + }) + response_data["todos"] = todos + response_data["preview"] = render_todos([TodoItem.model_validate(t) for t in todos]) + + response_data["status"] = "approved" + response_data["message"] = ( + "APPROVED with edit: added todo「生成 README 文件」— proceed with todo_write using the updated list." 
+ ) + response_data["approved_by"] = USER_ID + response_data["timestamp"] = time.time() + + print("\n👤 [Human approval with plan edit]") + print(f" Decision: approved by {response_data['approved_by']}") + print(" Edit: added todo → 生成 README 文件") + print(" Updated checklist:") + for line in str(response_data["preview"]).splitlines(): + print(f" {line}") + + resume_function_response = FunctionResponse( + id=long_running_event.function_response.id, + name=long_running_event.function_response.name, + response=response_data, + ) + return Content(role="user", parts=[Part(function_response=resume_function_response)]) + + +async def _run_turn( + runner: Runner, + agent: LlmAgent, + *, + session_id: str, + label: str, + query: str, +) -> None: + """Drive a user turn, including optional HITL resume within the same turn.""" + print(f"\n========== {label} ==========") + print(f"📝 User: {query}") + + user_content = Content(parts=[Part.from_text(text=query)]) + final_text_parts, long_running_event = await _consume_run( + runner, + session_id=session_id, + content=user_content, + ) + + if long_running_event: + resume_content = _build_approval_resume(long_running_event) + print("\n🔄 Resuming agent after human approval...") + resume_text_parts, _ = await _consume_run( + runner, + session_id=session_id, + content=resume_content, + ) + final_text_parts.extend(resume_text_parts) + + if final_text_parts: + print(f"🤖 Assistant: {''.join(final_text_parts)}") + + session = await runner.session_service.get_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=session_id, + ) + todos = get_todos(session, branch=agent.name) or get_todos(session, branch="") + print("\n📋 Persisted checklist:") + print(render_todos(todos) if todos else " (empty)") + print("-" * 40) + + +async def main() -> None: + from agent.agent import create_todo_agent + + work_dir = os.getcwd() + demo_dir = os.path.join(work_dir, "demo") + if os.path.isdir(demo_dir): + shutil.rmtree(demo_dir) + print(f"🧹 Cleaned 
previous {demo_dir}") + + todo_agent = create_todo_agent(work_dir=work_dir) + + runner = Runner( + app_name=APP_NAME, + agent=todo_agent, + session_service=InMemorySessionService(), + ) + + session_id = str(uuid.uuid4()) + await runner.session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=session_id, + state={"user_name": USER_ID}, + ) + print(f"🆔 Session ID: {session_id[:8]}... (shared across all turns)") + print(f"📂 Work dir: {work_dir}") + + for label, query in TURNS: + await _run_turn(runner, todo_agent, session_id=session_id, label=label, query=query) + + await runner.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/tools/task_tools/test_task_tools.py b/tests/tools/task_tools/test_task_tools.py new file mode 100644 index 0000000..51f8c15 --- /dev/null +++ b/tests/tools/task_tools/test_task_tools.py @@ -0,0 +1,327 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Tests for the Task tool family (:mod:`trpc_agent_sdk.tools.task_tools`).""" + +from __future__ import annotations + +import asyncio + +import pytest + +from trpc_agent_sdk.agents._base_agent import BaseAgent +from trpc_agent_sdk.context import InvocationContext, create_agent_context +from trpc_agent_sdk.events import Event +from trpc_agent_sdk.models import LlmRequest +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.tools.task_tools import ( + DEFAULT_STATE_KEY_PREFIX, + TaskCreateTool, + TaskGetTool, + TaskListTool, + TaskStatus, + TaskStore, + TaskToolSet, + TaskUpdateTool, + decode_store, + detect_cycle, + get_task_store, + render_task_list, + state_key, +) +from trpc_agent_sdk.tools.task_tools._models import TaskRecord +from trpc_agent_sdk.types import Content, EventActions, Part + + +class _StubAgent(BaseAgent): + async def _run_async_impl(self, ctx): + yield + + +@pytest.fixture +def session_bundle(): + service = InMemorySessionService() + session = asyncio.run( + service.create_session(app_name="test", user_id="u1", session_id="s1") + ) + agent = _StubAgent(name="task_planner") + ctx = InvocationContext( + session_service=service, + invocation_id="inv-1", + agent=agent, + agent_context=create_agent_context(), + session=session, + branch="", + ) + return service, session, agent, ctx + + +async def _create(ctx, subject, **kwargs): + tool = TaskCreateTool() + return await tool._run_async_impl(tool_context=ctx, args={"subject": subject, **kwargs}) + + +async def _update(ctx, task_id, **kwargs): + tool = TaskUpdateTool() + return await tool._run_async_impl(tool_context=ctx, args={"taskId": task_id, **kwargs}) + + +class TestTaskCreate: + @pytest.mark.asyncio + async def test_assigns_incrementing_id(self, session_bundle): + _, _, agent, ctx = session_bundle + first = await _create(ctx, "A") + second = await _create(ctx, "B") + assert first["task"]["id"] == "1" + assert second["task"]["id"] == "2" + assert 
state_key(DEFAULT_STATE_KEY_PREFIX, agent.name) in ctx.state._delta + + @pytest.mark.asyncio + async def test_rejects_empty_subject(self, session_bundle): + _, _, _, ctx = session_bundle + result = await _create(ctx, " ") + assert "error" in result + + @pytest.mark.asyncio + async def test_id_not_reused_after_delete(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A") + await _update(ctx, "1", status="deleted") + third = await _create(ctx, "C") + assert third["task"]["id"] == "2" + + +class TestTaskUpdate: + @pytest.mark.asyncio + async def test_status_transition(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A") + res = await _update(ctx, "1", status="in_progress") + assert res["task"]["status"] == "in_progress" + res = await _update(ctx, "1", status="completed") + assert res["task"]["status"] == "completed" + + @pytest.mark.asyncio + async def test_not_found(self, session_bundle): + _, _, _, ctx = session_bundle + res = await _update(ctx, "99", status="completed") + assert "NOT_FOUND" in res["error"] + + @pytest.mark.asyncio + async def test_single_in_progress_enforced(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A") + await _create(ctx, "B") + await _update(ctx, "1", status="in_progress") + res = await _update(ctx, "2", status="in_progress") + assert "error" in res + assert "in_progress" in res["error"] + + @pytest.mark.asyncio + async def test_single_in_progress_can_be_disabled(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A") + await _create(ctx, "B") + tool = TaskUpdateTool(enforce_single_in_progress=False) + await tool._run_async_impl(tool_context=ctx, args={"taskId": "1", "status": "in_progress"}) + res = await tool._run_async_impl(tool_context=ctx, args={"taskId": "2", "status": "in_progress"}) + assert "error" not in res + + @pytest.mark.asyncio + async def test_two_way_dependency_edges(self, session_bundle): + _, _, _, ctx = 
session_bundle + await _create(ctx, "schema") + await _create(ctx, "endpoints") + res = await _update(ctx, "2", addBlockedBy=["1"]) + assert res["task"]["blockedBy"] == ["1"] + store = get_task_store(ctx.session, branch="task_planner") + assert store.tasks["1"].blocks == ["2"] + + @pytest.mark.asyncio + async def test_complete_unblocks_downstream(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "schema") + await _create(ctx, "endpoints") + await _update(ctx, "2", addBlockedBy=["1"]) + res = await _update(ctx, "1", status="completed") + assert res["unblocked"] == ["2"] + store = get_task_store(ctx.session, branch="task_planner") + assert store.tasks["2"].blocked_by == [] + + @pytest.mark.asyncio + async def test_cycle_rejected(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A") + await _create(ctx, "B") + await _update(ctx, "2", addBlockedBy=["1"]) + res = await _update(ctx, "1", addBlockedBy=["2"]) + assert "INVALID_DEPENDENCY" in res["error"] + + @pytest.mark.asyncio + async def test_missing_dependency_rejected(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A") + res = await _update(ctx, "1", addBlockedBy=["99"]) + assert "INVALID_DEPENDENCY" in res["error"] + + @pytest.mark.asyncio + async def test_deleted_cannot_be_modified(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A") + await _update(ctx, "1", status="deleted") + res = await _update(ctx, "1", status="in_progress") + assert "error" in res + + +class TestTaskGetAndList: + @pytest.mark.asyncio + async def test_get_includes_description(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A", description="long detail") + res = await TaskGetTool()._run_async_impl(tool_context=ctx, args={"taskId": "1"}) + assert res["task"]["description"] == "long detail" + + @pytest.mark.asyncio + async def test_get_not_found(self, session_bundle): + _, _, _, ctx = session_bundle + res = 
await TaskGetTool()._run_async_impl(tool_context=ctx, args={"taskId": "1"}) + assert "NOT_FOUND" in res["error"] + + @pytest.mark.asyncio + async def test_list_omits_description_and_filters_deleted(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A", description="should not appear") + await _create(ctx, "B") + await _update(ctx, "2", status="deleted") + res = await TaskListTool()._run_async_impl(tool_context=ctx, args={}) + assert len(res["tasks"]) == 1 + assert "description" not in res["tasks"][0] + assert res["stats"]["pending"] == 1 + + @pytest.mark.asyncio + async def test_list_include_deleted(self, session_bundle): + _, _, _, ctx = session_bundle + await _create(ctx, "A") + await _update(ctx, "1", status="deleted") + res = await TaskListTool()._run_async_impl(tool_context=ctx, args={"includeDeleted": True}) + assert len(res["tasks"]) == 1 + + +class TestBranchIsolation: + @pytest.mark.asyncio + async def test_branches_are_independent(self, session_bundle): + service, session, agent, _ = session_bundle + ctx_a = InvocationContext( + session_service=service, invocation_id="i", agent=agent, + agent_context=create_agent_context(), session=session, branch="a", + ) + ctx_b = InvocationContext( + session_service=service, invocation_id="i", agent=agent, + agent_context=create_agent_context(), session=session, branch="b", + ) + await _create(ctx_a, "task in a") + await _create(ctx_b, "task in b") + store_a = decode_store(ctx_a.state._delta[state_key(DEFAULT_STATE_KEY_PREFIX, "a")]) + store_b = decode_store(ctx_b.state._delta[state_key(DEFAULT_STATE_KEY_PREFIX, "b")]) + assert store_a.tasks["1"].subject == "task in a" + assert store_b.tasks["1"].subject == "task in b" + + +class TestPersistence: + @pytest.mark.asyncio + async def test_store_survives_append_event_and_get_session(self, session_bundle): + service, session, agent, ctx = session_bundle + await _create(ctx, "A") + await _update(ctx, "1", status="in_progress") + + event = Event( + 
invocation_id="inv-1", + author=agent.name, + content=Content(parts=[Part.from_text(text="tool result")]), + actions=EventActions(state_delta=dict(ctx.event_actions.state_delta)), + ) + await service.append_event(session, event) + + stored = await service.get_session(app_name="test", user_id="u1", session_id="s1") + store = get_task_store(stored, branch=agent.name) + assert store.tasks["1"].status == TaskStatus.IN_PROGRESS + + +class TestConcurrency: + @pytest.mark.asyncio + async def test_parallel_create_assigns_unique_ids(self, session_bundle): + _, _, _, ctx = session_bundle + tool = TaskCreateTool() + results = await asyncio.gather( + *[ + tool._run_async_impl(tool_context=ctx, args={"subject": f"task-{i}"}) + for i in range(20) + ] + ) + ids = sorted(int(r["task"]["id"]) for r in results) + assert ids == list(range(1, 21)) + + @pytest.mark.asyncio + async def test_parallel_mixed_ops_preserve_all_tasks(self, session_bundle): + _, _, _, ctx = session_bundle + create = TaskCreateTool() + update = TaskUpdateTool() + await create._run_async_impl(tool_context=ctx, args={"subject": "seed"}) + await asyncio.gather( + create._run_async_impl(tool_context=ctx, args={"subject": "A"}), + create._run_async_impl(tool_context=ctx, args={"subject": "B"}), + update._run_async_impl(tool_context=ctx, args={"taskId": "1", "status": "in_progress"}), + ) + store = get_task_store(ctx.session, branch="task_planner") + assert set(store.tasks) == {"1", "2", "3"} + assert store.tasks["1"].status == TaskStatus.IN_PROGRESS + + +class TestHelpers: + def test_state_key(self): + assert state_key(DEFAULT_STATE_KEY_PREFIX, "") == "tasks" + assert state_key(DEFAULT_STATE_KEY_PREFIX, "planner") == "tasks:planner" + + def test_decode_dirty_data_degrades_to_empty(self): + assert decode_store("not json").tasks == {} + assert decode_store(None).tasks == {} + + def test_detect_cycle_on_clean_store(self): + store = TaskStore() + store.tasks["1"] = TaskRecord(id="1", subject="A") + store.tasks["2"] = 
TaskRecord(id="2", subject="B", blockedBy=["1"]) + assert detect_cycle(store) is None + + def test_render_task_list(self): + store = TaskStore() + store.tasks["1"] = TaskRecord(id="1", subject="Done", status=TaskStatus.COMPLETED) + store.tasks["2"] = TaskRecord( + id="2", subject="Active", activeForm="Doing active", status=TaskStatus.IN_PROGRESS + ) + store.tasks["3"] = TaskRecord(id="3", subject="Wait", blockedBy=["2"]) + rendered = render_task_list(store) + assert "✅ #1 Done" in rendered + assert "🔄 #2 Doing active" in rendered + assert "blocked by: 2" in rendered + +class TestProcessRequest: + @pytest.mark.asyncio + async def test_injects_prompt_once(self, session_bundle): + _, _, _, ctx = session_bundle + llm_request = LlmRequest() + await TaskCreateTool().process_request(tool_context=ctx, llm_request=llm_request) + await TaskUpdateTool().process_request(tool_context=ctx, llm_request=llm_request) + text = str(llm_request.config.system_instruction) + assert text.count("structured task board via the tools") == 1 + assert "task_create" in llm_request.tools_dict + assert "task_update" in llm_request.tools_dict + + +class TestTaskToolSet: + @pytest.mark.asyncio + async def test_returns_four_tools(self): + tools = await TaskToolSet().get_tools() + names = {t.name for t in tools} + assert names == {"task_create", "task_update", "task_get", "task_list"} diff --git a/tests/tools/test_todo_tool.py b/tests/tools/test_todo_tool.py new file mode 100644 index 0000000..5b30de8 --- /dev/null +++ b/tests/tools/test_todo_tool.py @@ -0,0 +1,211 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Tests for :mod:`trpc_agent_sdk.tools._todo_tool`.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from trpc_agent_sdk.agents._base_agent import BaseAgent +from trpc_agent_sdk.context import InvocationContext, create_agent_context +from trpc_agent_sdk.events import Event +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.models import LlmRequest +from trpc_agent_sdk.tools._todo_tool import ( + DEFAULT_STATE_KEY_PREFIX, + TodoItem, + TodoStatus, + TodoWriteTool, + get_todos, + render_todos, + state_key, + validate_todos, +) +from trpc_agent_sdk.types import Content, EventActions, Part, State + + +class _StubAgent(BaseAgent): + async def _run_async_impl(self, ctx): + yield + + +def _sample_todos(**statuses: str) -> list[dict]: + defaults = { + "step1": ("Run step 1", "Running step 1", TodoStatus.IN_PROGRESS), + "step2": ("Run step 2", "Running step 2", TodoStatus.PENDING), + } + return [{ + "content": defaults[name][0], + "activeForm": defaults[name][1], + "status": statuses.get(name, defaults[name][2].value), + } for name in ("step1", "step2")] + + +@pytest.fixture +def session_bundle(): + service = InMemorySessionService() + session = asyncio.run( + service.create_session(app_name="test", user_id="u1", session_id="s1") + ) + agent = _StubAgent(name="todo_planner") + ctx = InvocationContext( + session_service=service, + invocation_id="inv-1", + agent=agent, + agent_context=create_agent_context(), + session=session, + branch="", + ) + return service, session, agent, ctx + + +class TestValidateTodos: + def test_accepts_valid_list(self): + items = [ + TodoItem(content="A", activeForm="Doing A", status=TodoStatus.IN_PROGRESS), + TodoItem(content="B", activeForm="Doing B", status=TodoStatus.PENDING), + ] + assert validate_todos(items) is None + + def test_rejects_multiple_in_progress(self): + items = [ + TodoItem(content="A", activeForm="Doing A", status=TodoStatus.IN_PROGRESS), + TodoItem(content="B", 
activeForm="Doing B", status=TodoStatus.IN_PROGRESS), + ] + assert "in_progress" in (validate_todos(items) or "") + + def test_rejects_duplicate_content(self): + items = [ + TodoItem(content="Same", activeForm="Doing same", status=TodoStatus.IN_PROGRESS), + TodoItem(content="Same", activeForm="Still same", status=TodoStatus.PENDING), + ] + assert "duplicates" in (validate_todos(items) or "") + + +class TestStateKey: + def test_default_prefix_without_branch(self): + assert state_key(DEFAULT_STATE_KEY_PREFIX, "") == "todos" + + def test_appends_branch(self): + assert state_key(DEFAULT_STATE_KEY_PREFIX, "todo_planner") == "todos:todo_planner" + + +class TestProcessRequest: + @pytest.mark.asyncio + async def test_process_request_adds_instructions(self, session_bundle): + _, _, _, ctx = session_bundle + tool = TodoWriteTool() + llm_request = LlmRequest() + + await tool.process_request(tool_context=ctx, llm_request=llm_request) + + assert tool.name in llm_request.tools_dict + assert llm_request.config is not None + assert llm_request.config.system_instruction is not None + assert "todo_write" in str(llm_request.config.system_instruction).lower() + + +class TestTodoWriteTool: + @pytest.mark.asyncio + async def test_writes_and_returns_echo(self, session_bundle): + _, _, agent, ctx = session_bundle + tool = TodoWriteTool(clear_on_all_done=False) + payload = _sample_todos() + + result = await tool._run_async_impl(tool_context=ctx, args={"todos": payload}) + + assert "error" not in result + assert len(result["todos"]) == 2 + assert result["oldTodos"] is None + key = state_key(DEFAULT_STATE_KEY_PREFIX, agent.name) + assert key in ctx.state._delta + + @pytest.mark.asyncio + async def test_reads_previous_list_on_second_call(self, session_bundle): + _, _, agent, ctx = session_bundle + tool = TodoWriteTool(clear_on_all_done=False) + first = await tool._run_async_impl(tool_context=ctx, args={"todos": _sample_todos()}) + assert first["oldTodos"] is None + + second = await 
tool._run_async_impl( + tool_context=ctx, + args={"todos": _sample_todos(step1="completed", step2="in_progress")}, + ) + assert len(second["oldTodos"]) == 2 + assert second["todos"][0]["status"] == "completed" + assert second["todos"][1]["status"] == "in_progress" + assert state_key(DEFAULT_STATE_KEY_PREFIX, agent.name) in ctx.session.state + + @pytest.mark.asyncio + async def test_rejects_missing_todos_field(self, session_bundle): + _, _, _, ctx = session_bundle + result = await TodoWriteTool()._run_async_impl(tool_context=ctx, args={}) + assert "error" in result + + @pytest.mark.asyncio + async def test_clear_on_all_done(self, session_bundle): + _, _, _, ctx = session_bundle + tool = TodoWriteTool(clear_on_all_done=True) + result = await tool._run_async_impl( + tool_context=ctx, + args={"todos": _sample_todos(step1="completed", step2="completed")}, + ) + assert result["todos"] == [] + + +class TestPersistence: + @pytest.mark.asyncio + async def test_todos_prefix_survives_append_event_and_get_session(self, session_bundle): + service, session, agent, ctx = session_bundle + tool = TodoWriteTool(clear_on_all_done=False) + await tool._run_async_impl(tool_context=ctx, args={"todos": _sample_todos()}) + + event = Event( + invocation_id="inv-1", + author=agent.name, + content=Content(parts=[Part.from_text(text="tool result")]), + actions=EventActions(state_delta=dict(ctx.event_actions.state_delta)), + ) + await service.append_event(session, event) + + stored = await service.get_session(app_name="test", user_id="u1", session_id="s1") + todos = get_todos(stored, branch=agent.name) + assert len(todos) == 2 + assert todos[0].status == TodoStatus.IN_PROGRESS + + @pytest.mark.asyncio + async def test_temp_prefix_is_not_persisted(self, session_bundle): + service, session, agent, ctx = session_bundle + tool = TodoWriteTool(state_key_prefix="temp:todos", clear_on_all_done=False) + await tool._run_async_impl(tool_context=ctx, args={"todos": _sample_todos()}) + + event = Event( + 
invocation_id="inv-1", + author=agent.name, + content=Content(parts=[Part.from_text(text="tool result")]), + actions=EventActions(state_delta=dict(ctx.event_actions.state_delta)), + ) + await service.append_event(session, event) + + stored = await service.get_session(app_name="test", user_id="u1", session_id="s1") + key = state_key("temp:todos", agent.name) + assert key not in (stored.state or {}) + assert get_todos(stored, branch=agent.name, prefix="temp:todos") == [] + + +class TestRenderTodos: + def test_renders_checklist(self): + items = [ + TodoItem(content="Done task", activeForm="Doing done", status=TodoStatus.COMPLETED), + TodoItem(content="Active task", activeForm="Doing active", status=TodoStatus.IN_PROGRESS), + TodoItem(content="Pending task", activeForm="Doing pending", status=TodoStatus.PENDING), + ] + rendered = render_todos(items) + assert "✅ Done task" in rendered + assert "🔄 Doing active" in rendered + assert "⬜ Pending task" in rendered diff --git a/trpc_agent_sdk/tools/__init__.py b/trpc_agent_sdk/tools/__init__.py index 1ced02e..3efb355 100644 --- a/trpc_agent_sdk/tools/__init__.py +++ b/trpc_agent_sdk/tools/__init__.py @@ -34,9 +34,32 @@ from ._set_model_response_tool import SetModelResponseTool from ._streaming_function_tool import StreamingFunctionTool from ._streaming_progress_tool import StreamingProgressTool +from ._todo_tool import DEFAULT_NUDGE_MESSAGE +from ._todo_tool import DEFAULT_STATE_KEY_PREFIX +from ._todo_tool import DEFAULT_TODO_DESCRIPTION +from ._todo_tool import DEFAULT_TODO_PROMPT +from ._todo_tool import TodoItem +from ._todo_tool import TodoStatus +from ._todo_tool import TodoWriteTool +from ._todo_tool import get_todos +from ._todo_tool import render_todos +from ._todo_tool import state_key +from ._todo_tool import validate_todos from ._tool_adapter import convert_toolunion_to_tool_list from ._tool_adapter import create_tool from ._tool_adapter import create_toolset +from .task_tools import DEFAULT_TASK_PROMPT +from 
.task_tools import TaskCreateTool +from .task_tools import TaskGetTool +from .task_tools import TaskListSummary +from .task_tools import TaskListTool +from .task_tools import TaskRecord +from .task_tools import TaskStatus +from .task_tools import TaskStore +from .task_tools import TaskToolSet +from .task_tools import TaskUpdateTool +from .task_tools import get_task_store +from .task_tools import render_task_list from ._transfer_to_agent_tool import transfer_to_agent from ._webfetch_tool import FetchResult from ._webfetch_tool import WebFetchTool @@ -97,6 +120,29 @@ "create_tool", "create_toolset", "transfer_to_agent", + "TodoWriteTool", + "TodoItem", + "TodoStatus", + "get_todos", + "state_key", + "render_todos", + "validate_todos", + "DEFAULT_TODO_PROMPT", + "DEFAULT_TODO_DESCRIPTION", + "DEFAULT_NUDGE_MESSAGE", + "DEFAULT_STATE_KEY_PREFIX", + "TaskCreateTool", + "TaskUpdateTool", + "TaskGetTool", + "TaskListTool", + "TaskToolSet", + "TaskStatus", + "TaskRecord", + "TaskStore", + "TaskListSummary", + "get_task_store", + "render_task_list", + "DEFAULT_TASK_PROMPT", "FetchResult", "WebFetchTool", "SearchHit", diff --git a/trpc_agent_sdk/tools/_todo_tool.py b/trpc_agent_sdk/tools/_todo_tool.py new file mode 100644 index 0000000..50386a0 --- /dev/null +++ b/trpc_agent_sdk/tools/_todo_tool.py @@ -0,0 +1,378 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""TodoWrite tool for TRPC Agent framework. + +Provides a client-side :class:`TodoWriteTool` that lets an LLM agent +plan, track and report multi-step tasks via a single structured +checklist. The tool follows the Claude Code / DeepAgents "whole-list +replacement" model: + +- The model sends the **complete, updated** todo list on every call; the + new list entirely replaces the previous one (no smart merge). 
+- The list is persisted in **session-level state** (key prefix ``todos``, + no ``temp:`` prefix) so it survives across ``Runner.run_async`` + invocations. Keys with the ``temp:`` prefix are stripped by + :class:`~trpc_agent_sdk.sessions.BaseSessionService` and never land in + storage. +- Different branches (parent / sub agents) keep **independent** lists. + +The tool enforces a small set of *hard contracts* (well-formed input, +at most one ``in_progress`` item, unique ``content``) in code; softer +style guidance lives in :data:`DEFAULT_TODO_PROMPT`, appended to the +system instruction automatically via :meth:`TodoWriteTool.process_request`. + +The function-response payload echoes ``{message, todos, oldTodos}`` so a +front-end / CLI can render the current list and a diff without re-reading +session state. +""" + +from __future__ import annotations + +import json +from enum import Enum +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing_extensions import override + +from pydantic import BaseModel +from pydantic import Field +from pydantic import ValidationError + +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.filter import BaseFilter +from trpc_agent_sdk.log import logger +from trpc_agent_sdk.models import LlmRequest +from trpc_agent_sdk.types import FunctionDeclaration +from trpc_agent_sdk.types import Schema +from trpc_agent_sdk.types import Type + +from ._base_tool import BaseTool + +# Default tool name. ``snake_case`` satisfies the strict +# ``^[a-zA-Z0-9_-]+$`` function-name constraint some providers enforce. +_DEFAULT_TOOL_NAME = "todo_write" +# Default state-key prefix. Session-scoped (no ``temp:``) so +# BaseSessionService persists the list across Runner invocations. +DEFAULT_STATE_KEY_PREFIX = "todos" +# Default nudge appended to every successful response to keep the model +# anchored on the plan. +DEFAULT_NUDGE_MESSAGE = ("Todo list updated. 
Keep exactly one item in_progress, mark items completed " + "the moment they are done, and call todo_write again whenever the plan changes.") + +# Short description fed to the model as part of the function schema. +DEFAULT_TODO_DESCRIPTION = """\ +Create and manage a structured todo list for the current task. +- Call this to plan a multi-step task up front, then to flip each item's status as you make progress. +- Send the COMPLETE, updated list every time; it replaces the previous list entirely. +- Keep AT MOST ONE item `in_progress` at any moment; mark an item `completed` as soon as it is finished. +- The only way to clear the list is to send an explicit empty array (`todos: []`). + +USE WHEN: + - the task has 3+ distinct steps, or is non-trivial and benefits from explicit planning + - the user provides multiple tasks, or you discover follow-up work mid-task + +DO NOT USE WHEN: + - the task is a single trivial step, or is purely conversational/informational\ +""" + +# Long-form guidance appended via :meth:`TodoWriteTool.process_request`. +# Kept separate from the hard contract enforced by :func:`validate_todos`. +# Still exported for tests and callers who need the raw text. +DEFAULT_TODO_PROMPT = """\ +You have access to the `todo_write` tool to plan and track multi-step work. + +When to use it: + - Use it for any task with 3+ distinct steps, multi-file changes, or non-trivial work that benefits + from explicit planning. Plan first, then execute item by item. + - When the user gives multiple requests, capture each as a todo item before starting. + - Skip it for single trivial steps or purely informational questions. + +How to use it: + - Always send the COMPLETE, updated list. The new list replaces the old one entirely. + - Each item has `content` (imperative, e.g. "Run tests"), `activeForm` (present-continuous, e.g. + "Running tests"), and `status` (one of `pending`, `in_progress`, `completed`). 
+ - Keep exactly one item `in_progress` while you work on it; never leave a stale `in_progress`. + - Mark an item `completed` the moment it is done — do not batch completions at the end. + - To start a new step, set the previous one to `completed` and the next one to `in_progress` in the + same call. + - To clear the list, send `todos: []` (an explicit empty array). + +After calling `todo_write`, do not repeat the whole list back to the user — just continue the work and +summarise meaningful changes. +""" + +# Optional policy callback. Invoked after persistence, before returning; +# any non-empty string it returns is appended to the response message. +# Hooks are read-only and MUST NOT mutate the lists. +NudgeHook = Callable[[List["TodoItem"], List["TodoItem"]], Optional[str]] + + +class TodoStatus(str, Enum): + """Lifecycle state of a single todo item.""" + + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + + +class TodoItem(BaseModel): + """A single todo entry. + + ``active_form`` is exposed to the model / persisted under the + camelCase alias ``activeForm`` to stay compatible with the Go + implementation and Claude Code's schema. + """ + + content: str = Field(description="Imperative description of the task, e.g. 'Run tests'.") + active_form: str = Field( + alias="activeForm", + description="Present-continuous form shown while active, e.g. 'Running tests'.", + ) + status: TodoStatus = Field(description="One of: pending, in_progress, completed.") + + model_config = {"populate_by_name": True} + + +def validate_todos(todos: List[TodoItem]) -> Optional[str]: + """Enforce the hard contract on a todo list. + + Returns an error string when the list is invalid, or ``None`` when it + passes. Rules: + + - ``content`` and ``activeForm`` must be non-empty. + - At most one item may be ``in_progress``. + - ``content`` must be unique across the list (exact match; no trim / + case folding, to avoid silently merging items). 
+ """ + in_progress = 0 + seen: dict[str, int] = {} + for i, item in enumerate(todos): + if not item.content or not item.content.strip(): + return f"todos[{i}].content must not be empty" + if not item.active_form or not item.active_form.strip(): + return f"todos[{i}].activeForm must not be empty" + if item.status == TodoStatus.IN_PROGRESS: + in_progress += 1 + if in_progress > 1: + return f"at most one item may be in_progress (second one at todos[{i}])" + if item.content in seen: + return f"todos[{i}].content {item.content!r} duplicates todos[{seen[item.content]}]" + seen[item.content] = i + return None + + +def state_key(prefix: str, branch: str) -> str: + """Build the state key, appending ``:<branch>`` for sub-agent isolation.""" + prefix = prefix or DEFAULT_STATE_KEY_PREFIX + return prefix if not branch else f"{prefix}:{branch}" + + +def _decode_todos(raw: Any) -> List[TodoItem]: + """Decode a persisted value (JSON string or list) into ``TodoItem``s. + + Tolerates dirty / legacy data: anything that fails to parse is + treated as an empty list rather than raising. + """ + if not raw: + return [] + try: + data = json.loads(raw) if isinstance(raw, str) else raw + if not isinstance(data, list): + return [] + return [TodoItem.model_validate(x) for x in data] + except (ValueError, ValidationError, TypeError) as e: + logger.warning("TodoWriteTool failed to decode persisted todos: %s", e) + return [] + + +def get_todos( + session: Any, + branch: str = "", + prefix: str = DEFAULT_STATE_KEY_PREFIX, +) -> List[TodoItem]: + """Read the current todo list for ``branch`` from a session. + + Intended for server-side / REST / audit reads. ``session`` only needs + a ``state`` mapping attribute. Malformed data degrades to ``[]``. + """ + state = getattr(session, "state", None) or {} + return _decode_todos(state.get(state_key(prefix, branch))) + + +def render_todos(todos: List[TodoItem]) -> str: + """Render a plain-text checklist (``✅`` / ``🔄`` / ``⬜``).
+ + Convenience for CLIs / logs; the tool itself never calls this. + """ + glyph = { + TodoStatus.COMPLETED: "✅", + TodoStatus.IN_PROGRESS: "🔄", + TodoStatus.PENDING: "⬜", + } + lines = [] + for item in todos: + mark = glyph.get(item.status, "⬜") + text = item.active_form if item.status == TodoStatus.IN_PROGRESS else item.content + lines.append(f"{mark} {text}") + return "\n".join(lines) + + +class TodoWriteTool(BaseTool): + """LLM tool that maintains a structured, persistent todo checklist. + + The model sends the complete updated list on each call; the tool + validates it, persists it to branch-scoped session state, and returns + a nudge plus the new and previous lists for downstream rendering. + + Args: + state_key_prefix: State-key prefix; ``todos`` by default. Avoid + ``temp:`` — that prefix is invocation-only and is not stored. + clear_on_all_done: When every item is ``completed``, store an empty + list instead so finished items do not pile up across turns + (default ``True``). + default_nudge: Base message appended to every successful response. + nudge_hooks: Optional read-only policy callbacks; each returned + non-empty string is appended to the response message. + filters_name / filters: forwarded to :class:`BaseTool`. 
+ """ + + def __init__( + self, + *, + state_key_prefix: str = DEFAULT_STATE_KEY_PREFIX, + clear_on_all_done: bool = True, + default_nudge: str = DEFAULT_NUDGE_MESSAGE, + nudge_hooks: Optional[List[NudgeHook]] = None, + filters_name: Optional[List[str]] = None, + filters: Optional[List[BaseFilter]] = None, + ) -> None: + super().__init__( + name=_DEFAULT_TOOL_NAME, + description=DEFAULT_TODO_DESCRIPTION, + filters_name=filters_name, + filters=filters, + ) + self._prefix = state_key_prefix or DEFAULT_STATE_KEY_PREFIX + self._clear_on_all_done = bool(clear_on_all_done) + self._default_nudge = default_nudge + self._nudge_hooks = list(nudge_hooks or []) + + @override + def _get_declaration(self) -> FunctionDeclaration: + item_schema = Schema( + type=Type.OBJECT, + properties={ + "content": + Schema( + type=Type.STRING, + description="Imperative description of the task, e.g. 'Run tests'.", + ), + "activeForm": + Schema( + type=Type.STRING, + description="Present-continuous form shown while active, e.g. 'Running tests'.", + ), + "status": + Schema( + type=Type.STRING, + enum=["pending", "in_progress", "completed"], + description="Item status. Keep at most one item 'in_progress'.", + ), + }, + required=["content", "activeForm", "status"], + ) + return FunctionDeclaration( + name=self.name, + description=DEFAULT_TODO_DESCRIPTION, + parameters=Schema( + type=Type.OBJECT, + properties={ + "todos": + Schema( + type=Type.ARRAY, + items=item_schema, + description=("The complete, updated todo list. Replaces the previous list " + "entirely. 
Send [] to clear the list."), + ), + }, + required=["todos"], + ), + ) + + @override + async def process_request( + self, + *, + tool_context: InvocationContext, + llm_request: LlmRequest, + ) -> None: + """Register the declaration and inject behavioural guidance.""" + await super().process_request(tool_context=tool_context, llm_request=llm_request) + llm_request.append_instructions([DEFAULT_TODO_PROMPT]) + + @override + async def _run_async_impl( + self, + *, + tool_context: InvocationContext, + args: dict[str, Any], + ) -> Any: + # 1. decode-guard: a missing field or explicit null is rejected so + # a dropped key cannot silently wipe the plan. The only valid + # clear gesture is an explicit empty array. + if "todos" not in args: + return {"error": "INVALID_ARGS: `todos` is required and must be an array (use [] to clear)"} + raw_todos = args["todos"] + if raw_todos is None: + return {"error": "INVALID_ARGS: `todos` must be an array, got null (use [] to clear)"} + if not isinstance(raw_todos, list): + return {"error": "INVALID_ARGS: `todos` must be an array"} + + # 2. parse + hard-contract validation + try: + todos = [TodoItem.model_validate(x) for x in raw_todos] + except (ValidationError, TypeError) as e: + return {"error": f"INVALID_ARGS: each todo must have content/activeForm/status: {e}"} + if (err := validate_todos(todos)) is not None: + return {"error": f"INVALID_TODOS: {err}"} + + # 3. resolve branch + key (fall back to agent name for single-agent stability) + branch = tool_context.branch or tool_context.agent_name or "" + key = state_key(self._prefix, branch) + + # 4. read previous list for the diff + old = get_todos(tool_context.session, branch, self._prefix) + + # 5. clear-on-all-done normalisation + new = todos + if self._clear_on_all_done and new and all(t.status == TodoStatus.COMPLETED for t in new): + new = [] + + # 6. persist. 
Writing through ``tool_context.state`` records a + # state delta that the framework commits with the function + # response event, so the list survives across runs. + new_payload = [t.model_dump(mode="json", by_alias=True) for t in new] + tool_context.state[key] = json.dumps(new_payload, ensure_ascii=False) + + # 7. assemble message: base nudge + read-only policy hooks + message = self._default_nudge + for hook in self._nudge_hooks: + try: + extra = hook(old, todos) + except Exception as e: # pylint: disable=broad-except + logger.warning("TodoWriteTool nudge hook raised: %s", e) + continue + if extra: + message = f"{message}\n\n{extra}" + + # 8. echo current + previous list for direct front-end consumption + return { + "message": message, + "todos": new_payload, + "oldTodos": [t.model_dump(mode="json", by_alias=True) for t in old] if old else None, + } diff --git a/trpc_agent_sdk/tools/task_tools/__init__.py b/trpc_agent_sdk/tools/task_tools/__init__.py new file mode 100644 index 0000000..5f7b06c --- /dev/null +++ b/trpc_agent_sdk/tools/task_tools/__init__.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Task tool family. + +Structured, incrementally-updated task board aligned with Claude Code's +``TaskCreate`` / ``TaskUpdate`` / ``TaskGet`` / ``TaskList`` tools. Tasks are +identified by server-assigned ids, support ``blockedBy`` / ``blocks`` +dependency edges, and are persisted to branch-scoped session state. + +This complements :mod:`trpc_agent_sdk.tools._todo_tool` (whole-list +replacement); mount one or the other depending on the use case. 
+""" + +from ._helpers import DEFAULT_STATE_KEY_PREFIX +from ._helpers import decode_store +from ._helpers import encode_store +from ._helpers import get_task_store +from ._helpers import render_task_list +from ._helpers import state_key +from ._models import TaskListSummary +from ._models import TaskRecord +from ._models import TaskStatus +from ._models import TaskStore +from ._prompt import DEFAULT_TASK_CREATE_DESCRIPTION +from ._prompt import DEFAULT_TASK_GET_DESCRIPTION +from ._prompt import DEFAULT_TASK_LIST_DESCRIPTION +from ._prompt import DEFAULT_TASK_PROMPT +from ._prompt import DEFAULT_TASK_UPDATE_DESCRIPTION +from ._store import clear_dependency +from ._store import create_task +from ._store import list_summaries +from ._task_create_tool import TaskCreateTool +from ._task_get_tool import TaskGetTool +from ._task_list_tool import TaskListTool +from ._task_toolset import TaskToolSet +from ._task_update_tool import TaskUpdateTool +from ._validators import detect_cycle +from ._validators import validate_status + +__all__ = [ + "TaskStatus", + "TaskRecord", + "TaskStore", + "TaskListSummary", + "TaskCreateTool", + "TaskUpdateTool", + "TaskGetTool", + "TaskListTool", + "TaskToolSet", + "get_task_store", + "decode_store", + "encode_store", + "state_key", + "render_task_list", + "create_task", + "list_summaries", + "clear_dependency", + "detect_cycle", + "validate_status", + "DEFAULT_STATE_KEY_PREFIX", + "DEFAULT_TASK_PROMPT", + "DEFAULT_TASK_CREATE_DESCRIPTION", + "DEFAULT_TASK_UPDATE_DESCRIPTION", + "DEFAULT_TASK_GET_DESCRIPTION", + "DEFAULT_TASK_LIST_DESCRIPTION", +] diff --git a/trpc_agent_sdk/tools/task_tools/_base.py b/trpc_agent_sdk/tools/task_tools/_base.py new file mode 100644 index 0000000..218044d --- /dev/null +++ b/trpc_agent_sdk/tools/task_tools/_base.py @@ -0,0 +1,88 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. 
+# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Shared base for the Task tools: branch resolution, store I/O, prompt injection.""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any +from typing import Optional +from typing_extensions import override + +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.filter import BaseFilter +from trpc_agent_sdk.models import LlmRequest + +from .._base_tool import BaseTool +from ._helpers import DEFAULT_STATE_KEY_PREFIX +from ._helpers import decode_store +from ._helpers import encode_store +from ._helpers import state_key +from ._lock import task_store_lock +from ._models import TaskStore +from ._prompt import DEFAULT_TASK_PROMPT +from ._prompt import _PROMPT_MARKER + + +class _TaskToolBase(BaseTool): + """Common plumbing shared by the four Task tools. + + Handles branch-scoped state-key resolution, store load / save, and + one-time injection of :data:`DEFAULT_TASK_PROMPT` into the system + instruction (guarded so mounting several task tools does not duplicate + the guidance). 
+ """ + + def __init__( + self, + *, + name: str, + description: str, + state_key_prefix: str = DEFAULT_STATE_KEY_PREFIX, + inject_prompt: bool = True, + filters_name: Optional[list[str]] = None, + filters: Optional[list[BaseFilter]] = None, + ) -> None: + super().__init__( + name=name, + description=description, + filters_name=filters_name, + filters=filters, + ) + self._prefix = state_key_prefix or DEFAULT_STATE_KEY_PREFIX + self._inject_prompt = inject_prompt + + def _resolve_branch(self, tool_context: InvocationContext) -> str: + return tool_context.branch or tool_context.agent_name or "" + + def _load_store(self, tool_context: InvocationContext) -> TaskStore: + branch = self._resolve_branch(tool_context) + return decode_store(tool_context.state.get(state_key(self._prefix, branch))) + + def _save_store(self, tool_context: InvocationContext, store: TaskStore) -> None: + branch = self._resolve_branch(tool_context) + tool_context.state[state_key(self._prefix, branch)] = encode_store(store) + + @override + async def _run_async_impl(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any: + branch = self._resolve_branch(tool_context) + async with task_store_lock(tool_context, prefix=self._prefix, branch=branch): + return await self._run_task_store(tool_context=tool_context, args=args) + + @abstractmethod + async def _run_task_store(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any: + """Tool logic; called under :func:`task_store_lock` for the branch board.""" + + @override + async def process_request(self, *, tool_context: InvocationContext, llm_request: LlmRequest) -> None: + await super().process_request(tool_context=tool_context, llm_request=llm_request) + if not self._inject_prompt: + return + existing = "" + if llm_request.config and llm_request.config.system_instruction: + existing = str(llm_request.config.system_instruction) + if _PROMPT_MARKER not in existing: + llm_request.append_instructions([DEFAULT_TASK_PROMPT]) diff 
--git a/trpc_agent_sdk/tools/task_tools/_helpers.py b/trpc_agent_sdk/tools/task_tools/_helpers.py new file mode 100644 index 0000000..7248a5b --- /dev/null +++ b/trpc_agent_sdk/tools/task_tools/_helpers.py @@ -0,0 +1,98 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""State-key handling, store (de)serialisation and rendering for Task tools.""" + +from __future__ import annotations + +from typing import Any +from typing import List + +from trpc_agent_sdk.log import logger + +from ._models import TaskStatus +from ._models import TaskStore + +# Session-scoped (no ``temp:`` prefix, which BaseSessionService strips and +# never persists). Survives across Runner invocations. +DEFAULT_STATE_KEY_PREFIX = "tasks" + + +def state_key(prefix: str, branch: str) -> str: + """Build the state key, appending ``:<branch>`` for sub-agent isolation.""" + prefix = prefix or DEFAULT_STATE_KEY_PREFIX + return prefix if not branch else f"{prefix}:{branch}" + + +def decode_store(raw: Any) -> TaskStore: + """Decode a persisted value (JSON string / dict) into a :class:`TaskStore`. + + Tolerates dirty / legacy data: anything that fails to parse degrades to + an empty store rather than raising.
+ """ + if not raw: + return TaskStore() + try: + if isinstance(raw, str): + return TaskStore.model_validate_json(raw) + if isinstance(raw, dict): + return TaskStore.model_validate(raw) + except (ValueError, TypeError) as e: + logger.warning("Task tools failed to decode persisted store: %s", e) + return TaskStore() + + +def encode_store(store: TaskStore) -> str: + """Serialise a store to a JSON string (camelCase aliases, non-ASCII kept).""" + return store.model_dump_json(by_alias=True) + + +def get_task_store( + session: Any, + branch: str = "", + prefix: str = DEFAULT_STATE_KEY_PREFIX, +) -> TaskStore: + """Read the current task board for ``branch`` from a session. + + Intended for server-side / REST / audit reads. ``session`` only needs a + ``state`` mapping attribute. Malformed data degrades to an empty store. + """ + state = getattr(session, "state", None) or {} + return decode_store(state.get(state_key(prefix, branch))) + + +def render_task_list(store: TaskStore, *, include_deleted: bool = False) -> str: + """Render a plain-text checklist for CLIs / logs. + + Uses ``✅`` / ``🔄`` / ``⬜`` glyphs; shows blocking dependencies. + The tools themselves never call this. 
+ """ + glyph = { + TaskStatus.COMPLETED: "✅", + TaskStatus.IN_PROGRESS: "🔄", + TaskStatus.PENDING: "⬜", + TaskStatus.DELETED: "—", + } + lines: List[str] = [] + for tid in _sorted_ids(store): + task = store.tasks[tid] + if task.status == TaskStatus.DELETED and not include_deleted: + continue + mark = glyph.get(task.status, "⬜") + text = task.active_form if (task.status == TaskStatus.IN_PROGRESS and task.active_form) else task.subject + suffix = f" (blocked by: {', '.join(task.blocked_by)})" if task.blocked_by else "" + lines.append(f"{mark} #{task.id} {text}{suffix}") + return "\n".join(lines) + + +def _sorted_ids(store: TaskStore) -> List[str]: + + def _key(tid: str) -> Any: + try: + return (0, int(tid)) + except ValueError: + return (1, tid) + + return sorted(store.tasks.keys(), key=_key) diff --git a/trpc_agent_sdk/tools/task_tools/_lock.py b/trpc_agent_sdk/tools/task_tools/_lock.py new file mode 100644 index 0000000..01787e4 --- /dev/null +++ b/trpc_agent_sdk/tools/task_tools/_lock.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Per-session / per-branch locks for :class:`TaskStore` read-modify-write.""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import AsyncIterator +from typing import Dict +from typing import Tuple + +from trpc_agent_sdk.context import InvocationContext + +from ._helpers import state_key + +_LockKey = Tuple[str, str, str, str] +_locks: Dict[_LockKey, asyncio.Lock] = {} +_registry_lock = asyncio.Lock() + + +def store_lock_key( + tool_context: InvocationContext, + *, + prefix: str, + branch: str, +) -> _LockKey: + """Build a stable key for the branch-scoped task board.""" + session = tool_context.session + return ( + getattr(session, "app_name", "") or "", + getattr(session, "user_id", "") or "", + getattr(session, "id", "") or "", + state_key(prefix, branch), + ) + + +async def _get_lock(key: _LockKey) -> asyncio.Lock: + async with _registry_lock: + lock = _locks.get(key) + if lock is None: + lock = asyncio.Lock() + _locks[key] = lock + return lock + + +@asynccontextmanager +async def task_store_lock( + tool_context: InvocationContext, + *, + prefix: str, + branch: str, +) -> AsyncIterator[None]: + """Serialize load / mutate / save for one task board.""" + lock = await _get_lock(store_lock_key(tool_context, prefix=prefix, branch=branch)) + async with lock: + yield + + +def reset_locks_for_tests() -> None: + """Clear the lock registry (tests only).""" + _locks.clear() diff --git a/trpc_agent_sdk/tools/task_tools/_models.py b/trpc_agent_sdk/tools/task_tools/_models.py new file mode 100644 index 0000000..5cbd711 --- /dev/null +++ b/trpc_agent_sdk/tools/task_tools/_models.py @@ -0,0 +1,92 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Data models for the Task tool family. 
+ +These mirror Claude Code's structured Task tools (``TaskCreate`` / +``TaskUpdate`` / ``TaskGet`` / ``TaskList``). A task is identified by a +server-assigned ``id`` and updated incrementally, in contrast with the +whole-list-replacement model used by :mod:`trpc_agent_sdk.tools._todo_tool`. + +The full board (:class:`TaskStore`) is serialised as a single JSON blob +into session-level state so it survives across ``Runner.run_async`` +invocations and stays internally consistent on each read-modify-write. +:class:`_TaskToolBase` serialises access per session branch with an +``asyncio.Lock`` so parallel tool calls cannot corrupt ``highwatermark`` +or ``tasks``. +""" + +from __future__ import annotations + +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from pydantic import BaseModel +from pydantic import Field + + +class TaskStatus(str, Enum): + """Lifecycle state of a single task.""" + + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + # Soft delete: kept in the store (so ids are never reused) but filtered + # out of ``task_list`` by default. + DELETED = "deleted" + + +class TaskRecord(BaseModel): + """A single task entry. + + ``active_form`` / ``blocked_by`` are exposed to the model and persisted + under the camelCase aliases ``activeForm`` / ``blockedBy`` to stay + compatible with Claude Code's schema. + """ + + id: str = Field(description="Server-assigned, monotonically increasing id.") + subject: str = Field(description="Short imperative title, e.g. 
'Run tests'.") + description: str = Field(default="", description="Free-form details.") + active_form: Optional[str] = Field( + default=None, + alias="activeForm", + description="Present-continuous form shown while in_progress.", + ) + owner: Optional[str] = Field(default=None, description="Claiming agent / worker id.") + status: TaskStatus = Field(default=TaskStatus.PENDING, description="Lifecycle state.") + blocks: List[str] = Field(default_factory=list, description="Downstream task ids this task blocks.") + blocked_by: List[str] = Field( + default_factory=list, + alias="blockedBy", + description="Upstream task ids that block this task.", + ) + metadata: Optional[Dict[str, Any]] = Field(default=None, description="Arbitrary extension data.") + + model_config = {"populate_by_name": True} + + +class TaskStore(BaseModel): + """The complete in-session task board, serialised into session state.""" + + highwatermark: int = Field(default=0, description="Highest id ever allocated; ids are never reused.") + tasks: Dict[str, TaskRecord] = Field(default_factory=dict, description="Task id -> record.") + + model_config = {"populate_by_name": True} + + +class TaskListSummary(BaseModel): + """Token-optimised summary returned by ``task_list`` (no ``description``).""" + + id: str + subject: str + status: TaskStatus + owner: Optional[str] = None + active_form: Optional[str] = Field(default=None, alias="activeForm") + blocked_by: List[str] = Field(default_factory=list, alias="blockedBy") + + model_config = {"populate_by_name": True} diff --git a/trpc_agent_sdk/tools/task_tools/_prompt.py b/trpc_agent_sdk/tools/task_tools/_prompt.py new file mode 100644 index 0000000..02ab6ca --- /dev/null +++ b/trpc_agent_sdk/tools/task_tools/_prompt.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Tool descriptions and behavioural guidance for the Task tool family.""" + +from __future__ import annotations + +# Per-tool short descriptions fed to the model as part of each function schema. +DEFAULT_TASK_CREATE_DESCRIPTION = """\ +Create a new task on the structured task board and return its assigned id. +- Use this to plan a multi-step task up front: create one task per distinct step. +- The id is assigned by the server and returned in the result; do NOT invent ids. +- Set dependencies and status afterwards with task_update.\ +""" + +DEFAULT_TASK_UPDATE_DESCRIPTION = """\ +Incrementally update one task by id: change status, fields, owner or dependencies. +- Set status to 'in_progress' before starting a task and 'completed' the moment it is done. +- Keep at most one task 'in_progress' at a time. +- Use addBlockedBy/removeBlockedBy (or addBlocks/removeBlocks) to maintain dependencies. +- Set status to 'deleted' to remove a task; its id will not be reused.\ +""" + +DEFAULT_TASK_GET_DESCRIPTION = """\ +Get the full details of a single task by id, including its description and dependencies.\ +""" + +DEFAULT_TASK_LIST_DESCRIPTION = """\ +List all tasks as a compact summary (id, subject, status, owner, blockedBy). +- The summary intentionally omits descriptions to save tokens; use task_get for details.\ +""" + +# Long-form guidance, injected once into the system instruction. +DEFAULT_TASK_PROMPT = """\ +You have access to a structured task board via the tools: task_create, task_update, +task_get and task_list. Use it to plan, track and order multi-step work. + +When to use it: + - Use the task board for work with multiple steps, dependencies between steps, or work + that spans several turns. Plan first by creating tasks, then execute them one by one. + - When the user gives multiple requests, create a task for each before starting. + - Skip it for single trivial steps or purely informational questions. + +How to use it: + - Plan: call task_create once per step. 
The id comes back in the result — never invent ids. + - Order: declare dependencies with task_update addBlockedBy (an upstream task must complete + first). Do not start a task whose dependencies are unfinished. + - Progress: before working on a task call task_update with status 'in_progress'; mark it + 'completed' the moment it is done. Keep exactly one task 'in_progress' at a time. + - Read back: use task_list to review the board (summaries only) and task_get for the full + details of a specific task. Do not assume task_list returns descriptions. + - Remove a task with task_update status 'deleted'. + +After calling these tools, do not repeat the whole board back to the user — just continue the +work and summarise meaningful changes. +""" + +# Sentinel substring used to avoid injecting the long prompt more than once +# when several task tools are mounted on the same agent. +_PROMPT_MARKER = "structured task board via the tools: task_create" diff --git a/trpc_agent_sdk/tools/task_tools/_store.py b/trpc_agent_sdk/tools/task_tools/_store.py new file mode 100644 index 0000000..02834a3 --- /dev/null +++ b/trpc_agent_sdk/tools/task_tools/_store.py @@ -0,0 +1,161 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""In-session CRUD over :class:`TaskStore` with dependency maintenance. + +All functions operate on an in-memory :class:`TaskStore`; persistence is +the caller's responsibility (tools write the serialised store back through +``tool_context.state``). The store is the single source of truth for one +branch's task board, so a read-modify-write of the whole blob keeps the +two-way ``blocks`` / ``blocked_by`` edges consistent. 
+""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from ._models import TaskListSummary +from ._models import TaskRecord +from ._models import TaskStatus +from ._models import TaskStore + + +def create_task( + store: TaskStore, + *, + subject: str, + description: str = "", + active_form: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> TaskRecord: + """Allocate a new id and insert a pending task. Mutates ``store``.""" + store.highwatermark += 1 + task_id = str(store.highwatermark) + record = TaskRecord( + id=task_id, + subject=subject, + description=description, + active_form=active_form, + status=TaskStatus.PENDING, + metadata=metadata, + ) + store.tasks[task_id] = record + return record + + +def get_task(store: TaskStore, task_id: str) -> Optional[TaskRecord]: + """Return the record for ``task_id`` or ``None``.""" + return store.tasks.get(task_id) + + +def list_summaries(store: TaskStore, *, include_deleted: bool = False) -> List[TaskListSummary]: + """Return token-optimised summaries, sorted by numeric id.""" + summaries: List[TaskListSummary] = [] + for tid in _sorted_ids(store): + task = store.tasks[tid] + if task.status == TaskStatus.DELETED and not include_deleted: + continue + summaries.append( + TaskListSummary( + id=task.id, + subject=task.subject, + status=task.status, + owner=task.owner, + active_form=task.active_form, + blocked_by=list(task.blocked_by), + )) + return summaries + + +def stats(store: TaskStore) -> Dict[str, int]: + """Count non-deleted tasks by status.""" + counts = {s.value: 0 for s in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS, TaskStatus.COMPLETED)} + for task in store.tasks.values(): + if task.status in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS, TaskStatus.COMPLETED): + counts[task.status.value] += 1 + return counts + + +def add_blocked_by(store: TaskStore, task_id: str, upstream_ids: List[str]) -> None: + 
"""Add upstream dependencies, maintaining the reverse ``blocks`` edge.""" + task = store.tasks[task_id] + for upstream in upstream_ids: + if upstream == task_id: + continue + if upstream not in task.blocked_by: + task.blocked_by.append(upstream) + up = store.tasks.get(upstream) + if up is not None and task_id not in up.blocks: + up.blocks.append(task_id) + + +def remove_blocked_by(store: TaskStore, task_id: str, upstream_ids: List[str]) -> None: + """Remove upstream dependencies, maintaining the reverse ``blocks`` edge.""" + task = store.tasks[task_id] + remove = set(upstream_ids) + task.blocked_by = [u for u in task.blocked_by if u not in remove] + for upstream in upstream_ids: + up = store.tasks.get(upstream) + if up is not None and task_id in up.blocks: + up.blocks.remove(task_id) + + +def add_blocks(store: TaskStore, task_id: str, downstream_ids: List[str]) -> None: + """Add downstream blocks, maintaining the reverse ``blocked_by`` edge.""" + task = store.tasks[task_id] + for downstream in downstream_ids: + if downstream == task_id: + continue + if downstream not in task.blocks: + task.blocks.append(downstream) + down = store.tasks.get(downstream) + if down is not None and task_id not in down.blocked_by: + down.blocked_by.append(task_id) + + +def remove_blocks(store: TaskStore, task_id: str, downstream_ids: List[str]) -> None: + """Remove downstream blocks, maintaining the reverse ``blocked_by`` edge.""" + task = store.tasks[task_id] + remove = set(downstream_ids) + task.blocks = [d for d in task.blocks if d not in remove] + for downstream in downstream_ids: + down = store.tasks.get(downstream) + if down is not None and task_id in down.blocked_by: + down.blocked_by.remove(task_id) + + +def clear_dependency(store: TaskStore, completed_id: str) -> List[str]: + """Remove ``completed_id`` from every other task's ``blocked_by``. + + Returns the ids of tasks that became fully unblocked (no remaining + ``blocked_by`` and still pending) as a result. 
+ """ + unblocked: List[str] = [] + for tid in _sorted_ids(store): + task = store.tasks[tid] + if completed_id in task.blocked_by: + task.blocked_by.remove(completed_id) + if not task.blocked_by and task.status == TaskStatus.PENDING: + unblocked.append(tid) + # The completed task no longer blocks anything. + completed = store.tasks.get(completed_id) + if completed is not None: + completed.blocks = [] + return unblocked + + +def _sorted_ids(store: TaskStore) -> List[str]: + """Task ids sorted numerically (ids are stringified integers).""" + + def _key(tid: str) -> Any: + try: + return (0, int(tid)) + except ValueError: + return (1, tid) + + return sorted(store.tasks.keys(), key=_key) diff --git a/trpc_agent_sdk/tools/task_tools/_task_create_tool.py b/trpc_agent_sdk/tools/task_tools/_task_create_tool.py new file mode 100644 index 0000000..d1b2a0c --- /dev/null +++ b/trpc_agent_sdk/tools/task_tools/_task_create_tool.py @@ -0,0 +1,98 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""``task_create``: create a task and return its server-assigned id."""
+
+from __future__ import annotations
+
+from typing import Any
+from typing_extensions import override
+
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.types import FunctionDeclaration
+from trpc_agent_sdk.types import Schema
+from trpc_agent_sdk.types import Type
+
+from ._base import _TaskToolBase
+from ._prompt import DEFAULT_TASK_CREATE_DESCRIPTION
+from ._store import create_task
+
+_TOOL_NAME = "task_create"
+
+
+class TaskCreateTool(_TaskToolBase):
+    """Create a new task on the branch-scoped task board."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(name=_TOOL_NAME, description=DEFAULT_TASK_CREATE_DESCRIPTION, **kwargs)
+
+    @override
+    def _get_declaration(self) -> FunctionDeclaration:
+        return FunctionDeclaration(
+            name=self.name,
+            description=DEFAULT_TASK_CREATE_DESCRIPTION,
+            parameters=Schema(
+                type=Type.OBJECT,
+                properties={
+                    "subject":
+                    Schema(
+                        type=Type.STRING,
+                        description="Short imperative title, e.g. 'Run tests'.",
+                    ),
+                    "description":
+                    Schema(
+                        type=Type.STRING,
+                        description="Optional free-form details for the task.",
+                    ),
+                    "activeForm":
+                    Schema(
+                        type=Type.STRING,
+                        description="Optional present-continuous form, e.g. 'Running tests'.",
+                    ),
+                    "metadata":
+                    Schema(
+                        type=Type.OBJECT,
+                        description="Optional arbitrary key/value extension data.",
+                    ),
+                },
+                required=["subject"],
+            ),
+        )
+
+    @override
+    async def _run_task_store(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+        subject = args.get("subject")
+        if not isinstance(subject, str) or not subject.strip():
+            return {"error": "INVALID_ARGS: `subject` is required and must be a non-empty string"}
+
+        description = args.get("description") or ""
+        if not isinstance(description, str):
+            return {"error": "INVALID_ARGS: `description` must be a string"}
+
+        active_form = args.get("activeForm")
+        if active_form is not None and not isinstance(active_form, str):
+            return {"error": "INVALID_ARGS: `activeForm` must be a string"}
+
+        metadata = args.get("metadata")
+        if metadata is not None and not isinstance(metadata, dict):
+            return {"error": "INVALID_ARGS: `metadata` must be an object"}
+
+        store = self._load_store(tool_context)
+        record = create_task(
+            store,
+            subject=subject,
+            description=description,
+            active_form=active_form,
+            metadata=metadata,
+        )
+        self._save_store(tool_context, store)
+
+        return {
+            "task": {
+                "id": record.id,
+                "subject": record.subject
+            },
+            "message": (f"Task {record.id} created. Use task_update to set status or dependencies."),
+        }
diff --git a/trpc_agent_sdk/tools/task_tools/_task_get_tool.py b/trpc_agent_sdk/tools/task_tools/_task_get_tool.py
new file mode 100644
index 0000000..0d41a13
--- /dev/null
+++ b/trpc_agent_sdk/tools/task_tools/_task_get_tool.py
@@ -0,0 +1,55 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""``task_get``: read the full details of a single task by id."""
+
+from __future__ import annotations
+
+from typing import Any
+from typing_extensions import override
+
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.types import FunctionDeclaration
+from trpc_agent_sdk.types import Schema
+from trpc_agent_sdk.types import Type
+
+from ._base import _TaskToolBase
+from ._prompt import DEFAULT_TASK_GET_DESCRIPTION
+
+_TOOL_NAME = "task_get"
+
+
+class TaskGetTool(_TaskToolBase):
+    """Return the complete record (including description) for one task."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(name=_TOOL_NAME, description=DEFAULT_TASK_GET_DESCRIPTION, **kwargs)
+
+    @override
+    def _get_declaration(self) -> FunctionDeclaration:
+        return FunctionDeclaration(
+            name=self.name,
+            description=DEFAULT_TASK_GET_DESCRIPTION,
+            parameters=Schema(
+                type=Type.OBJECT,
+                properties={
+                    "taskId": Schema(type=Type.STRING, description="Id of the task to fetch."),
+                },
+                required=["taskId"],
+            ),
+        )
+
+    @override
+    async def _run_task_store(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+        task_id = args.get("taskId")
+        if not isinstance(task_id, str) or not task_id:
+            return {"error": "INVALID_ARGS: `taskId` is required and must be a non-empty string"}
+
+        store = self._load_store(tool_context)
+        task = store.tasks.get(task_id)
+        if task is None:
+            return {"error": f"NOT_FOUND: task {task_id!r} does not exist"}
+
+        return {"task": task.model_dump(mode="json", by_alias=True)}
diff --git a/trpc_agent_sdk/tools/task_tools/_task_list_tool.py b/trpc_agent_sdk/tools/task_tools/_task_list_tool.py
new file mode 100644
index 0000000..3108e1c
--- /dev/null
+++ b/trpc_agent_sdk/tools/task_tools/_task_list_tool.py
@@ -0,0 +1,57 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""``task_list``: list all tasks as a compact, token-optimised summary."""
+
+from __future__ import annotations
+
+from typing import Any
+from typing_extensions import override
+
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.types import FunctionDeclaration
+from trpc_agent_sdk.types import Schema
+from trpc_agent_sdk.types import Type
+
+from ._base import _TaskToolBase
+from ._prompt import DEFAULT_TASK_LIST_DESCRIPTION
+from ._store import list_summaries
+from ._store import stats
+
+_TOOL_NAME = "task_list"
+
+
+class TaskListTool(_TaskToolBase):
+    """List task summaries (no descriptions) plus per-status counts."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(name=_TOOL_NAME, description=DEFAULT_TASK_LIST_DESCRIPTION, **kwargs)
+
+    @override
+    def _get_declaration(self) -> FunctionDeclaration:
+        return FunctionDeclaration(
+            name=self.name,
+            description=DEFAULT_TASK_LIST_DESCRIPTION,
+            parameters=Schema(
+                type=Type.OBJECT,
+                properties={
+                    "includeDeleted":
+                    Schema(
+                        type=Type.BOOLEAN,
+                        description="Include soft-deleted tasks in the list (default false).",
+                    ),
+                },
+            ),
+        )
+
+    @override
+    async def _run_task_store(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+        include_deleted = bool(args.get("includeDeleted", False))
+        store = self._load_store(tool_context)
+        summaries = list_summaries(store, include_deleted=include_deleted)
+        return {
+            "tasks": [s.model_dump(mode="json", by_alias=True) for s in summaries],
+            "stats": stats(store),
+        }
diff --git a/trpc_agent_sdk/tools/task_tools/_task_toolset.py b/trpc_agent_sdk/tools/task_tools/_task_toolset.py
new file mode 100644
index 0000000..baf702b
--- /dev/null
+++ b/trpc_agent_sdk/tools/task_tools/_task_toolset.py
@@ -0,0 +1,66 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""``TaskToolSet``: bundles the four Task tools as a single toolset."""
+
+from __future__ import annotations
+
+from typing import List
+from typing import Optional
+from typing_extensions import override
+
+from trpc_agent_sdk.abc import ToolSetABC
+from trpc_agent_sdk.context import InvocationContext
+
+from .._base_tool import BaseTool
+from ._helpers import DEFAULT_STATE_KEY_PREFIX
+from ._task_create_tool import TaskCreateTool
+from ._task_get_tool import TaskGetTool
+from ._task_list_tool import TaskListTool
+from ._task_update_tool import TaskUpdateTool
+
+
+class TaskToolSet(ToolSetABC):
+    """Toolset exposing ``task_create`` / ``task_update`` / ``task_get`` / ``task_list``.
+
+    The structured task board aligns with Claude Code's Task tools: tasks are
+    created with server-assigned ids and updated incrementally by id, with
+    ``blockedBy`` / ``blocks`` dependency edges. The board is persisted to
+    branch-scoped session state and survives across Runner invocations.
+
+    Args:
+        state_key_prefix: State-key prefix; ``tasks`` by default. Avoid
+            ``temp:``; that prefix is invocation-only and is not stored.
+        enforce_single_in_progress: Reject setting a task ``in_progress``
+            while another already is (default ``True``).
+        inject_prompt: Inject :data:`DEFAULT_TASK_PROMPT` into the system
+            instruction once (default ``True``).
+    """
+
+    def __init__(
+        self,
+        *,
+        state_key_prefix: str = DEFAULT_STATE_KEY_PREFIX,
+        enforce_single_in_progress: bool = True,
+        inject_prompt: bool = True,
+        name: str = "task_toolset",
+    ) -> None:
+        super().__init__(name=name)
+        self._prefix = state_key_prefix or DEFAULT_STATE_KEY_PREFIX
+        self._enforce_single_in_progress = bool(enforce_single_in_progress)
+        self._inject_prompt = bool(inject_prompt)
+
+    @override
+    async def get_tools(self, invocation_context: Optional[InvocationContext] = None) -> List[BaseTool]:
+        return [
+            TaskCreateTool(state_key_prefix=self._prefix, inject_prompt=self._inject_prompt),
+            TaskUpdateTool(
+                state_key_prefix=self._prefix,
+                inject_prompt=self._inject_prompt,
+                enforce_single_in_progress=self._enforce_single_in_progress,
+            ),
+            TaskGetTool(state_key_prefix=self._prefix, inject_prompt=self._inject_prompt),
+            TaskListTool(state_key_prefix=self._prefix, inject_prompt=self._inject_prompt),
+        ]
diff --git a/trpc_agent_sdk/tools/task_tools/_task_update_tool.py b/trpc_agent_sdk/tools/task_tools/_task_update_tool.py
new file mode 100644
index 0000000..3703332
--- /dev/null
+++ b/trpc_agent_sdk/tools/task_tools/_task_update_tool.py
@@ -0,0 +1,184 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""``task_update``: incrementally patch one task by id."""
+
+from __future__ import annotations
+
+from typing import Any
+from typing import List
+from typing import Optional
+from typing_extensions import override
+
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.types import FunctionDeclaration
+from trpc_agent_sdk.types import Schema
+from trpc_agent_sdk.types import Type
+
+from ._base import _TaskToolBase
+from ._models import TaskStatus
+from ._prompt import DEFAULT_TASK_UPDATE_DESCRIPTION
+from ._store import add_blocked_by
+from ._store import add_blocks
+from ._store import clear_dependency
+from ._store import remove_blocked_by
+from ._store import remove_blocks
+from ._validators import detect_cycle
+from ._validators import validate_dependencies_exist
+from ._validators import validate_status
+
+_TOOL_NAME = "task_update"
+
+
+def _as_str_list(value: Any) -> Optional[List[str]]:
+    """Coerce an arg into a list of string ids, or ``None`` if absent/invalid."""
+    if value is None:
+        return None
+    if not isinstance(value, list):
+        return None
+    return [str(v) for v in value]
+
+
+class TaskUpdateTool(_TaskToolBase):
+    """Update a single task by id: status, fields, owner or dependencies.
+
+    Args:
+        enforce_single_in_progress: When ``True`` (default), reject setting a
+            task ``in_progress`` while another task already is.
+    """
+
+    def __init__(self, *, enforce_single_in_progress: bool = True, **kwargs: Any) -> None:
+        super().__init__(name=_TOOL_NAME, description=DEFAULT_TASK_UPDATE_DESCRIPTION, **kwargs)
+        self._enforce_single_in_progress = bool(enforce_single_in_progress)
+
+    @override
+    def _get_declaration(self) -> FunctionDeclaration:
+        id_array = Schema(type=Type.ARRAY, items=Schema(type=Type.STRING))
+        return FunctionDeclaration(
+            name=self.name,
+            description=DEFAULT_TASK_UPDATE_DESCRIPTION,
+            parameters=Schema(
+                type=Type.OBJECT,
+                properties={
+                    "taskId":
+                    Schema(type=Type.STRING, description="Id of the task to update."),
+                    "status":
+                    Schema(
+                        type=Type.STRING,
+                        enum=["pending", "in_progress", "completed", "deleted"],
+                        description="New status. Keep at most one task 'in_progress'.",
+                    ),
+                    "subject":
+                    Schema(type=Type.STRING, description="New subject."),
+                    "description":
+                    Schema(type=Type.STRING, description="New description."),
+                    "activeForm":
+                    Schema(type=Type.STRING, description="New present-continuous form."),
+                    "owner":
+                    Schema(type=Type.STRING, description="Claiming agent / worker id."),
+                    "addBlockedBy":
+                    Schema(type=Type.ARRAY,
+                           items=Schema(type=Type.STRING),
+                           description="Upstream task ids to add as dependencies."),
+                    "removeBlockedBy":
+                    id_array,
+                    "addBlocks":
+                    Schema(type=Type.ARRAY,
+                           items=Schema(type=Type.STRING),
+                           description="Downstream task ids this task should block."),
+                    "removeBlocks":
+                    id_array,
+                    "metadata":
+                    Schema(type=Type.OBJECT, description="Replacement extension data."),
+                },
+                required=["taskId"],
+            ),
+        )
+
+    @override
+    async def _run_task_store(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+        task_id = args.get("taskId")
+        if not isinstance(task_id, str) or not task_id:
+            return {"error": "INVALID_ARGS: `taskId` is required and must be a non-empty string"}
+
+        store = self._load_store(tool_context)
+        task = store.tasks.get(task_id)
+        if task is None:
+            return {"error": f"NOT_FOUND: task {task_id!r} does not exist"}
+
+        new_status_raw = args.get("status")
+        if new_status_raw is not None:
+            if (err := validate_status(new_status_raw)) is not None:
+                return {"error": f"INVALID_ARGS: {err}"}
+        if task.status == TaskStatus.DELETED and new_status_raw not in (None, TaskStatus.DELETED.value):
+            return {"error": f"INVALID_ARGS: task {task_id!r} is deleted and cannot be modified"}
+
+        # Validate dependency reference existence before mutating anything.
+        add_blocked = _as_str_list(args.get("addBlockedBy")) or []
+        remove_blocked = _as_str_list(args.get("removeBlockedBy")) or []
+        add_blk = _as_str_list(args.get("addBlocks")) or []
+        remove_blk = _as_str_list(args.get("removeBlocks")) or []
+        for ref in (add_blocked, add_blk):
+            if (err := validate_dependencies_exist(store, ref)) is not None:
+                return {"error": f"INVALID_DEPENDENCY: {err}"}
+
+        # Single-in-progress guard: setting this task in_progress must not
+        # leave two tasks in_progress at once.
+        if (self._enforce_single_in_progress and new_status_raw == TaskStatus.IN_PROGRESS.value):
+            others = [tid for tid, t in store.tasks.items() if t.status == TaskStatus.IN_PROGRESS and tid != task_id]
+            if others:
+                return {"error": f"INVALID_STATUS: task {others[0]!r} is already in_progress"}
+
+        # Apply scalar field patches.
+        if "subject" in args and args["subject"] is not None:
+            if not isinstance(args["subject"], str) or not args["subject"].strip():
+                return {"error": "INVALID_ARGS: `subject` must be a non-empty string"}
+            task.subject = args["subject"]
+        if "description" in args and args["description"] is not None:
+            task.description = str(args["description"])
+        if "activeForm" in args and args["activeForm"] is not None:
+            task.active_form = str(args["activeForm"])
+        if "owner" in args and args["owner"] is not None:
+            task.owner = str(args["owner"])
+        if "metadata" in args and args["metadata"] is not None:
+            if not isinstance(args["metadata"], dict):
+                return {"error": "INVALID_ARGS: `metadata` must be an object"}
+            task.metadata = args["metadata"]
+
+        # Apply dependency edits (two-way edges maintained by the store).
+        if add_blocked:
+            add_blocked_by(store, task_id, add_blocked)
+        if remove_blocked:
+            remove_blocked_by(store, task_id, remove_blocked)
+        if add_blk:
+            add_blocks(store, task_id, add_blk)
+        if remove_blk:
+            remove_blocks(store, task_id, remove_blk)
+
+        if (err := detect_cycle(store)) is not None:
+            # Reject the whole update; nothing has been persisted yet.
+            return {"error": f"INVALID_DEPENDENCY: {err}"}
+
+        # Apply status last so unblock computation reflects edited deps.
+        unblocked: List[str] = []
+        if new_status_raw is not None:
+            task.status = TaskStatus(new_status_raw)
+            if task.status == TaskStatus.COMPLETED:
+                unblocked = clear_dependency(store, task_id)
+
+        self._save_store(tool_context, store)
+
+        return {
+            "task": task.model_dump(mode="json", by_alias=True),
+            "unblocked": unblocked,
+            "message": _build_message(task_id, task.status, unblocked),
+        }
+
+
+def _build_message(task_id: str, status: TaskStatus, unblocked: List[str]) -> str:
+    msg = f"Task {task_id} updated (status={status.value})."
+    if unblocked:
+        msg += f" Unblocked: {', '.join(unblocked)}."
+    return msg
diff --git a/trpc_agent_sdk/tools/task_tools/_validators.py b/trpc_agent_sdk/tools/task_tools/_validators.py
new file mode 100644
index 0000000..25c7c3f
--- /dev/null
+++ b/trpc_agent_sdk/tools/task_tools/_validators.py
@@ -0,0 +1,106 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Hard-contract validators for the Task tool family.
+
+These enforce structural invariants in code (well-formed input, valid
+status transitions, acyclic dependency graph, optional single
+``in_progress``). Softer style guidance lives in :mod:`._prompt`.
+"""
+
+from __future__ import annotations
+
+from typing import List
+from typing import Optional
+
+from ._models import TaskRecord
+from ._models import TaskStatus
+from ._models import TaskStore
+
+# Statuses a model may set via task_update.
+_ASSIGNABLE_STATUSES = {
+    TaskStatus.PENDING,
+    TaskStatus.IN_PROGRESS,
+    TaskStatus.COMPLETED,
+    TaskStatus.DELETED,
+}
+
+
+def validate_status(status: str) -> Optional[str]:
+    """Return an error string if ``status`` is not assignable, else ``None``."""
+    try:
+        parsed = TaskStatus(status)
+    except ValueError:
+        valid = ", ".join(sorted(s.value for s in _ASSIGNABLE_STATUSES))  # sorted: set order is not stable
+        return f"invalid status {status!r}; must be one of: {valid}"
+    if parsed not in _ASSIGNABLE_STATUSES:
+        return f"status {status!r} is not assignable"
+    return None
+
+
+def validate_single_in_progress(store: TaskStore, exclude_id: Optional[str] = None) -> Optional[str]:
+    """Return an error if more than one task is ``in_progress``.
+
+    The task named by ``exclude_id`` is left out of the count so the caller
+    can validate a prospective state for a single task it is about to update.
+    """
+    in_progress = [tid for tid, t in store.tasks.items() if t.status == TaskStatus.IN_PROGRESS and tid != exclude_id]
+    if len(in_progress) > 1:
+        return f"at most one task may be in_progress (found {in_progress})"
+    return None
+
+
+def detect_cycle(store: TaskStore) -> Optional[str]:
+    """Detect a cycle in the ``blocked_by`` dependency graph.
+
+    Edge semantics: ``A.blocked_by = [B]`` means B must complete before A,
+    i.e. a directed edge ``A -> B``. A cycle means a circular dependency
+    that can never be satisfied. Deleted tasks are skipped. Returns an
+    error string naming a task on the cycle, or ``None`` when acyclic.
+    """
+    # 0 = unvisited, 1 = on current DFS stack, 2 = fully explored.
+    color: dict[str, int] = {}
+
+    def visit(node: str) -> Optional[str]:
+        task = store.tasks.get(node)
+        if task is None or task.status == TaskStatus.DELETED:
+            color[node] = 2
+            return None
+        color[node] = 1
+        for dep in task.blocked_by:
+            state = color.get(dep, 0)
+            if state == 1:
+                return f"dependency cycle detected involving task {dep!r}"
+            if state == 0:
+                err = visit(dep)
+                if err is not None:
+                    return err
+        color[node] = 2
+        return None
+
+    for tid in store.tasks:
+        if color.get(tid, 0) == 0:
+            err = visit(tid)
+            if err is not None:
+                return err
+    return None
+
+
+def validate_dependencies_exist(store: TaskStore, ids: List[str]) -> Optional[str]:
+    """Return an error if any id in ``ids`` is missing or deleted."""
+    for tid in ids:
+        task = store.tasks.get(tid)
+        if task is None:
+            return f"referenced task {tid!r} does not exist"
+        if task.status == TaskStatus.DELETED:
+            return f"referenced task {tid!r} is deleted"
+    return None
+
+
+def validate_task(task: TaskRecord) -> Optional[str]:
+    """Validate a single record's basic field contract."""
+    if not task.subject or not task.subject.strip():
+        return "subject must not be empty"
+    return None
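
Review note on `detect_cycle`: the three-colour depth-first search in `_validators.py` is easier to reason about in isolation. The sketch below is a standalone approximation, not the SDK code. It assumes a plain dict mapping each task id to its `blocked_by` list in place of `TaskStore`, it leaves out the deleted-task handling, and `find_cycle` and the sample ids are hypothetical names used only for illustration.

```python
from typing import Dict, List, Optional

# Colour codes for the traversal (same meaning as 0 / 1 / 2 in the diff).
UNVISITED, ON_STACK, DONE = 0, 1, 2


def find_cycle(blocked_by: Dict[str, List[str]]) -> Optional[str]:
    """Return the id of a task on a dependency cycle, or None if acyclic.

    ``blocked_by[a] == [b]`` is read as the directed edge ``a -> b``.
    Hypothetical stand-in for ``detect_cycle``; ids missing from the
    mapping are treated as tasks without dependencies.
    """
    color: Dict[str, int] = {}

    def visit(node: str) -> Optional[str]:
        color[node] = ON_STACK
        for dep in blocked_by.get(node, []):
            state = color.get(dep, UNVISITED)
            if state == ON_STACK:
                # Back edge: dep is still on the current DFS path.
                return dep
            if state == UNVISITED:
                found = visit(dep)
                if found is not None:
                    return found
        color[node] = DONE
        return None

    for task_id in blocked_by:
        if color.get(task_id, UNVISITED) == UNVISITED:
            found = visit(task_id)
            if found is not None:
                return found
    return None


acyclic = {"1": [], "2": ["1"], "3": ["1", "2"]}
cyclic = {"1": ["3"], "2": ["1"], "3": ["2"]}
print(find_cycle(acyclic))  # prints None
print(find_cycle(cyclic))   # prints 1
```

A node marked `DONE` is never revisited, so the check stays linear in tasks plus edges. That matters here because `task_update` runs it over the whole board after every dependency edit. Like the original, the sketch recurses, so a very deep dependency chain could hit Python's recursion limit.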