Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions tests/models/test_openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,103 @@ async def mock_stream():
# Should have multiple partial responses plus final response
assert len(responses) > 1

# Marked explicitly, like the other async tests in this file, so the coroutine
# is actually awaited even when pytest-asyncio runs in strict mode.
@pytest.mark.asyncio
async def test_generate_async_streaming_preserves_text_with_native_tool_call(self):
    """Final streaming response keeps regular text alongside native tool calls.

    The mocked stream yields one plain-text delta followed by one native
    ``tool_calls`` delta. The last response emitted by ``generate_async`` is
    expected to be non-partial and to contain both the text part and the
    function call with its arguments decoded into a dict.
    """
    model = OpenAIModel(model_name="gpt-4", api_key="test_key")
    content = Content(parts=[Part.from_text(text="What's the weather?")], role="user")
    request = LlmRequest(contents=[content], config=None, tools_dict={})

    # First chunk: ordinary assistant text, stream not finished yet.
    text_chunk = Mock()
    text_chunk.model_dump.return_value = {
        "choices": [{
            "delta": {
                "content": "I'll check the weather first."
            },
            "finish_reason": None,
            "index": 0,
        }],
        "usage": None,
    }

    # Second chunk: a complete native tool call that also ends the stream.
    tool_chunk = Mock()
    tool_chunk.model_dump.return_value = {
        "choices": [{
            "delta": {
                "tool_calls": [{
                    "index": 0,
                    "id": "call_weather",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Beijing"}',
                    },
                }]
            },
            "finish_reason": "tool_calls",
            "index": 0,
        }],
        "usage": None,
    }

    async def mock_stream():
        yield text_chunk
        yield tool_chunk

    with patch.object(model, '_create_async_client') as mock_client_factory:
        mock_client = AsyncMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream())
        mock_client.close = AsyncMock()
        mock_client_factory.return_value = mock_client

        responses = []
        async for response in model.generate_async(request, stream=True):
            responses.append(response)

    # Only the last emitted response is inspected; earlier ones are partials.
    final_response = responses[-1]
    assert final_response.partial is False
    assert final_response.content is not None

    text_parts = [part.text for part in final_response.content.parts if part.text]
    function_parts = [part.function_call for part in final_response.content.parts if part.function_call]

    assert text_parts == ["I'll check the weather first."]
    assert len(function_parts) == 1
    assert function_parts[0].name == "get_weather"
    assert function_parts[0].args == {"city": "Beijing"}

@pytest.mark.asyncio
async def test_generate_async_streaming_suppresses_tool_prompt_markup_but_keeps_visible_text(self):
    """Tool-prompt markup is stripped from the final text; the visible prose survives.

    A single streamed chunk carries human-readable text followed by an inline
    ``<tool_call>`` block. The final response must expose only the readable
    text as text, and the tool call as a function-call part.
    """
    model = OpenAIModel(model_name="hy3-preview", api_key="test_key", add_tools_to_prompt=True)
    user_turn = Content(parts=[Part.from_text(text="What's the weather?")], role="user")
    request = LlmRequest(contents=[user_turn], config=None, tools_dict={})

    # One chunk containing both the visible sentence and the prompt-style tool markup.
    chunk_payload = {
        "choices": [{
            "delta": {
                "content": ("I'll check the weather first. "
                            "<tool_call>get_weather<tool_sep>"
                            "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
                            "</tool_call>")
            },
            "finish_reason": "stop",
            "index": 0,
        }],
        "usage": None,
    }
    tool_prompt_chunk = Mock()
    tool_prompt_chunk.model_dump.return_value = chunk_payload

    async def mock_stream():
        yield tool_prompt_chunk

    # Build the fake client up front, then hand it to the patched factory.
    fake_client = AsyncMock()
    fake_client.chat.completions.create = AsyncMock(return_value=mock_stream())
    fake_client.close = AsyncMock()

    with patch.object(model, "_create_async_client") as client_factory:
        client_factory.return_value = fake_client
        emitted = [item async for item in model.generate_async(request, stream=True)]

    last = emitted[-1]
    assert last.partial is False
    assert last.content is not None

    # Split the final parts into visible text fragments and function calls.
    text_fragments = []
    calls = []
    for part in last.content.parts:
        if part.text:
            text_fragments.append(part.text)
        if part.function_call:
            calls.append(part.function_call)
    text = "".join(text_fragments)

    assert text == "I'll check the weather first. "
    assert "<tool_call>" not in text
    assert len(calls) == 1
    only_call = calls[0]
    assert only_call.name == "get_weather"
    assert only_call.args == {"city": "Beijing"}

@pytest.mark.asyncio
async def test_generate_async_error_handling(self):
"""Test generate_async handles API errors gracefully."""
Expand Down
14 changes: 11 additions & 3 deletions trpc_agent_sdk/models/_openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,7 @@ async def _generate_stream(self,
http_options = {}
thought_content = ""
accumulated_content = ""
visible_content = ""
last_usage = None
accumulated_tool_calls: list[dict] = []
is_thinking = False # Track whether we're currently in thinking mode
Expand Down Expand Up @@ -1668,6 +1669,8 @@ async def _generate_stream(self,
partial_text = self._adapter.filter_streaming_text(content, content_filter_state)
if not partial_text:
continue
if not is_thinking:
visible_content += partial_text

# Set thought flag based on current thinking state
content_part = Part.from_text(text=partial_text)
Expand Down Expand Up @@ -1700,6 +1703,8 @@ async def _generate_stream(self,

flushed_content_text = self._adapter.flush_streaming_text(streaming_text_filter_state["content"])
if flushed_content_text:
if not is_thinking:
visible_content += flushed_content_text
content_part = Part.from_text(text=flushed_content_text)
content_part.thought = is_thinking
partial_content = Content(parts=[content_part], role=const.MODEL)
Expand Down Expand Up @@ -1739,9 +1744,12 @@ async def _generate_stream(self,
logger.warning("Failed to parse function calls from final accumulated content: %s", ex)

# Add text content if present
if accumulated_content and not complete_tool_calls:
logger.debug("Final accumulated regular content: %s...", accumulated_content[:200])
content_part = Part.from_text(text=accumulated_content)
final_text_content = accumulated_content
if complete_tool_calls and self._adapter.should_suppress_tool_prompt_text():
final_text_content = visible_content
if final_text_content:
logger.debug("Final accumulated regular content: %s...", final_text_content[:200])
content_part = Part.from_text(text=final_text_content)
content_part.thought = False # Final accumulated content represents the answer, not thinking
parts.append(content_part)

Expand Down
Loading