From 957406ca37afa850f3ed014f5e4c8b404acb7057 Mon Sep 17 00:00:00 2001
From: raychen <815315825@qq.com>
Date: Thu, 25 Jun 2026 10:57:13 +0800
Subject: [PATCH] =?UTF-8?q?bugfix:=20=E4=BF=AE=E5=A4=8Dtool=E6=96=87?=
=?UTF-8?q?=E6=9C=AC=E5=92=8C=E6=AD=A3=E5=B8=B8=E6=96=87=E6=9C=AC=E5=90=8C?=
=?UTF-8?q?=E6=97=B6=E5=AD=98=E5=9C=A8=E7=9A=84=E6=97=B6=E5=80=99=E6=96=87?=
=?UTF-8?q?=E6=9C=AC=E4=B8=A2=E5=A4=B1=E7=9A=84=E8=BE=B9=E7=95=8C=E9=97=AE?=
=?UTF-8?q?=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
场景:
当模型不支持tool调用时,框架会通过文本解析的方式做tool适配,并过滤这部分tool文本,避免它被当做正常的模型推理数据。但当tool文本和正常文本同时存在时,正常文本也会被一并过滤掉
解决:
对正常文本和模型推理的文本做进一步的区分,避免漏了数据
---
tests/models/test_openai_model.py | 97 ++++++++++++++++++++++++++
trpc_agent_sdk/models/_openai_model.py | 14 +++-
2 files changed, 108 insertions(+), 3 deletions(-)
diff --git a/tests/models/test_openai_model.py b/tests/models/test_openai_model.py
index 2d4efd9..2cb910d 100644
--- a/tests/models/test_openai_model.py
+++ b/tests/models/test_openai_model.py
@@ -625,6 +625,104 @@ async def mock_stream():
# Should have multiple partial responses plus final response
assert len(responses) > 1
+ @pytest.mark.asyncio
+ async def test_generate_async_streaming_preserves_text_with_native_tool_call(self):
+ """Final streaming response keeps regular text alongside native tool calls."""
+ model = OpenAIModel(model_name="gpt-4", api_key="test_key")
+ content = Content(parts=[Part.from_text(text="What's the weather?")], role="user")
+ request = LlmRequest(contents=[content], config=None, tools_dict={})
+ text_chunk = Mock()  # first streamed chunk: plain text delta, no finish_reason yet
+ text_chunk.model_dump.return_value = {
+ "choices": [{
+ "delta": {
+ "content": "I'll check the weather first."
+ },
+ "finish_reason": None,
+ "index": 0,
+ }],
+ "usage": None,
+ }
+ tool_chunk = Mock()  # second streamed chunk: a complete native tool call that ends the stream
+ tool_chunk.model_dump.return_value = {
+ "choices": [{
+ "delta": {
+ "tool_calls": [{
+ "index": 0,
+ "id": "call_weather",
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "arguments": '{"city": "Beijing"}',
+ },
+ }]
+ },
+ "finish_reason": "tool_calls",
+ "index": 0,
+ }],
+ "usage": None,
+ }
+ async def mock_stream():
+ yield text_chunk
+ yield tool_chunk
+ with patch.object(model, '_create_async_client') as mock_client_factory:
+ mock_client = AsyncMock()
+ mock_client.chat.completions.create = AsyncMock(return_value=mock_stream())
+ mock_client.close = AsyncMock()
+ mock_client_factory.return_value = mock_client
+ responses = []
+ async for response in model.generate_async(request, stream=True):
+ responses.append(response)
+ final_response = responses[-1]  # only the last (non-partial) response is checked
+ assert final_response.partial is False
+ assert final_response.content is not None
+ text_parts = [part.text for part in final_response.content.parts if part.text]
+ function_parts = [part.function_call for part in final_response.content.parts if part.function_call]
+ assert text_parts == ["I'll check the weather first."]  # text must survive next to the tool call
+ assert len(function_parts) == 1
+ assert function_parts[0].name == "get_weather"
+ assert function_parts[0].args == {"city": "Beijing"}
+
+ @pytest.mark.asyncio
+ async def test_generate_async_streaming_suppresses_tool_prompt_markup_but_keeps_visible_text(self):
+ """Provider tool-prompt markup is hidden from final text while visible text is preserved."""
+ model = OpenAIModel(model_name="hy3-preview", api_key="test_key", add_tools_to_prompt=True)
+ content = Content(parts=[Part.from_text(text="What's the weather?")], role="user")
+ request = LlmRequest(contents=[content], config=None, tools_dict={})
+ tool_prompt_chunk = Mock()  # single chunk mixing visible text with a text-encoded tool call
+ tool_prompt_chunk.model_dump.return_value = {
+ "choices": [{
+ "delta": {
+ "content": ("I'll check the weather first. "
+ "get_weather"
+ "cityBeijing"
+ "")  # NOTE(review): tool-call markup around these fragments looks stripped; confirm against the adapter's format
+ },
+ "finish_reason": "stop",
+ "index": 0,
+ }],
+ "usage": None,
+ }
+ async def mock_stream():
+ yield tool_prompt_chunk
+ with patch.object(model, '_create_async_client') as mock_client_factory:
+ mock_client = AsyncMock()
+ mock_client.chat.completions.create = AsyncMock(return_value=mock_stream())
+ mock_client.close = AsyncMock()
+ mock_client_factory.return_value = mock_client
+ responses = []
+ async for response in model.generate_async(request, stream=True):
+ responses.append(response)
+ final_response = responses[-1]  # only the last (non-partial) response is checked
+ assert final_response.partial is False
+ assert final_response.content is not None
+ text = "".join(part.text for part in final_response.content.parts if part.text)
+ function_parts = [part.function_call for part in final_response.content.parts if part.function_call]
+ assert text == "I'll check the weather first. "
+ assert "get_weather" not in text and "Beijing" not in text  # '"" not in text' is always False; check the payload does not leak instead
+ assert len(function_parts) == 1
+ assert function_parts[0].name == "get_weather"
+ assert function_parts[0].args == {"city": "Beijing"}
+
@pytest.mark.asyncio
async def test_generate_async_error_handling(self):
"""Test generate_async handles API errors gracefully."""
diff --git a/trpc_agent_sdk/models/_openai_model.py b/trpc_agent_sdk/models/_openai_model.py
index d07ed68..bd30ae3 100644
--- a/trpc_agent_sdk/models/_openai_model.py
+++ b/trpc_agent_sdk/models/_openai_model.py
@@ -1548,6 +1548,7 @@ async def _generate_stream(self,
http_options = {}
thought_content = ""
accumulated_content = ""
+ visible_content = ""
last_usage = None
accumulated_tool_calls: list[dict] = []
is_thinking = False # Track whether we're currently in thinking mode
@@ -1668,6 +1669,8 @@ async def _generate_stream(self,
partial_text = self._adapter.filter_streaming_text(content, content_filter_state)
if not partial_text:
continue
+ if not is_thinking:
+ visible_content += partial_text
# Set thought flag based on current thinking state
content_part = Part.from_text(text=partial_text)
@@ -1700,6 +1703,8 @@ async def _generate_stream(self,
flushed_content_text = self._adapter.flush_streaming_text(streaming_text_filter_state["content"])
if flushed_content_text:
+ if not is_thinking:
+ visible_content += flushed_content_text
content_part = Part.from_text(text=flushed_content_text)
content_part.thought = is_thinking
partial_content = Content(parts=[content_part], role=const.MODEL)
@@ -1739,9 +1744,12 @@ async def _generate_stream(self,
logger.warning("Failed to parse function calls from final accumulated content: %s", ex)
# Add text content if present
- if accumulated_content and not complete_tool_calls:
- logger.debug("Final accumulated regular content: %s...", accumulated_content[:200])
- content_part = Part.from_text(text=accumulated_content)
+ final_text_content = accumulated_content
+ if complete_tool_calls and self._adapter.should_suppress_tool_prompt_text():
+ final_text_content = visible_content
+ if final_text_content:
+ logger.debug("Final accumulated regular content: %s...", final_text_content[:200])
+ content_part = Part.from_text(text=final_text_content)
content_part.thought = False # Final accumulated content represents the answer, not thinking
parts.append(content_part)