diff --git a/tests/models/test_openai_model.py b/tests/models/test_openai_model.py index 2d4efd9..2cb910d 100644 --- a/tests/models/test_openai_model.py +++ b/tests/models/test_openai_model.py @@ -625,6 +625,103 @@ async def mock_stream(): # Should have multiple partial responses plus final response assert len(responses) > 1 + async def test_generate_async_streaming_preserves_text_with_native_tool_call(self): + """Final streaming response keeps regular text alongside native tool calls.""" + model = OpenAIModel(model_name="gpt-4", api_key="test_key") + content = Content(parts=[Part.from_text(text="What's the weather?")], role="user") + request = LlmRequest(contents=[content], config=None, tools_dict={}) + text_chunk = Mock() + text_chunk.model_dump.return_value = { + "choices": [{ + "delta": { + "content": "I'll check the weather first." + }, + "finish_reason": None, + "index": 0, + }], + "usage": None, + } + tool_chunk = Mock() + tool_chunk.model_dump.return_value = { + "choices": [{ + "delta": { + "tool_calls": [{ + "index": 0, + "id": "call_weather", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Beijing"}', + }, + }] + }, + "finish_reason": "tool_calls", + "index": 0, + }], + "usage": None, + } + async def mock_stream(): + yield text_chunk + yield tool_chunk + with patch.object(model, '_create_async_client') as mock_client_factory: + mock_client = AsyncMock() + mock_client.chat.completions.create = AsyncMock(return_value=mock_stream()) + mock_client.close = AsyncMock() + mock_client_factory.return_value = mock_client + responses = [] + async for response in model.generate_async(request, stream=True): + responses.append(response) + final_response = responses[-1] + assert final_response.partial is False + assert final_response.content is not None + text_parts = [part.text for part in final_response.content.parts if part.text] + function_parts = [part.function_call for part in final_response.content.parts if part.function_call] + assert text_parts == ["I'll check the weather first."] + assert len(function_parts) == 1 + assert function_parts[0].name == "get_weather" + assert function_parts[0].args == {"city": "Beijing"} + + @pytest.mark.asyncio + async def test_generate_async_streaming_suppresses_tool_prompt_markup_but_keeps_visible_text(self): + """Provider tool-prompt markup is hidden from final text while visible text is preserved.""" + model = OpenAIModel(model_name="hy3-preview", api_key="test_key", add_tools_to_prompt=True) + content = Content(parts=[Part.from_text(text="What's the weather?")], role="user") + request = LlmRequest(contents=[content], config=None, tools_dict={}) + tool_prompt_chunk = Mock() + tool_prompt_chunk.model_dump.return_value = { + "choices": [{ + "delta": { + "content": ("I'll check the weather first. " + "get_weather" + "cityBeijing" + "") + }, + "finish_reason": "stop", + "index": 0, + }], + "usage": None, + } + async def mock_stream(): + yield tool_prompt_chunk + with patch.object(model, '_create_async_client') as mock_client_factory: + mock_client = AsyncMock() + mock_client.chat.completions.create = AsyncMock(return_value=mock_stream()) + mock_client.close = AsyncMock() + mock_client_factory.return_value = mock_client + responses = [] + async for response in model.generate_async(request, stream=True): + responses.append(response) + final_response = responses[-1] + assert final_response.partial is False + assert final_response.content is not None + text = "".join(part.text for part in final_response.content.parts if part.text) + function_parts = [part.function_call for part in final_response.content.parts if part.function_call] + assert text == "I'll check the weather first. " + assert "" not in text + assert len(function_parts) == 1 + assert function_parts[0].name == "get_weather" + assert function_parts[0].args == {"city": "Beijing"} + @pytest.mark.asyncio async def test_generate_async_error_handling(self): """Test generate_async handles API errors gracefully.""" diff --git a/trpc_agent_sdk/models/_openai_model.py b/trpc_agent_sdk/models/_openai_model.py index d07ed68..bd30ae3 100644 --- a/trpc_agent_sdk/models/_openai_model.py +++ b/trpc_agent_sdk/models/_openai_model.py @@ -1548,6 +1548,7 @@ async def _generate_stream(self, http_options = {} thought_content = "" accumulated_content = "" + visible_content = "" last_usage = None accumulated_tool_calls: list[dict] = [] is_thinking = False # Track whether we're currently in thinking mode @@ -1668,6 +1669,8 @@ async def _generate_stream(self, partial_text = self._adapter.filter_streaming_text(content, content_filter_state) if not partial_text: continue + if not is_thinking: + visible_content += partial_text # Set thought flag based on current thinking state content_part = Part.from_text(text=partial_text) @@ -1700,6 +1703,8 @@ async def _generate_stream(self, flushed_content_text = self._adapter.flush_streaming_text(streaming_text_filter_state["content"]) if flushed_content_text: + if not is_thinking: + visible_content += flushed_content_text content_part = Part.from_text(text=flushed_content_text) content_part.thought = is_thinking partial_content = Content(parts=[content_part], role=const.MODEL) @@ -1739,9 +1744,12 @@ async def _generate_stream(self, logger.warning("Failed to parse function calls from final accumulated content: %s", ex) # Add text content if present - if accumulated_content and not complete_tool_calls: - logger.debug("Final accumulated regular content: %s...", accumulated_content[:200]) - content_part = Part.from_text(text=accumulated_content) + final_text_content = accumulated_content + if complete_tool_calls and self._adapter.should_suppress_tool_prompt_text(): + final_text_content = visible_content + if final_text_content: + logger.debug("Final accumulated regular content: %s...", final_text_content[:200]) + content_part = Part.from_text(text=final_text_content) content_part.thought = False # Final accumulated content represents the answer, not thinking parts.append(content_part)