From 341f56f80f1e8d0161d7f23384b71fc0cbb3fda1 Mon Sep 17 00:00:00 2001 From: raychen <815315825@qq.com> Date: Tue, 23 Jun 2026 16:17:54 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E5=86=99http=20client?= =?UTF-8?q?=E4=BC=A0=E5=85=A5=E6=A8=A1=E5=9E=8B=E7=9A=84=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 支持用户传入http client 的构建对象 - 默认的构建对象是临时的http client, 如果期望共享http client可以自行配置 --- docs/mkdocs/en/model.md | 25 ++ docs/mkdocs/zh/model.md | 25 ++ tests/models/test_anthropic_model.py | 4 +- tests/models/test_anthropic_model_ext.py | 4 +- tests/models/test_openai_model.py | 188 ++++++---- trpc_agent_sdk/agents/_base_agent.py | 4 +- trpc_agent_sdk/models/__init__.py | 8 + trpc_agent_sdk/models/_anthropic_model.py | 20 +- trpc_agent_sdk/models/_httpx_client.py | 144 ++++++++ trpc_agent_sdk/models/_openai_model.py | 418 +++++++++++----------- trpc_agent_sdk/runners.py | 4 +- trpc_agent_sdk/telemetry/_metrics.py | 2 +- 12 files changed, 554 insertions(+), 292 deletions(-) create mode 100644 trpc_agent_sdk/models/_httpx_client.py diff --git a/docs/mkdocs/en/model.md b/docs/mkdocs/en/model.md index 1a8b6203..5b9618e7 100644 --- a/docs/mkdocs/en/model.md +++ b/docs/mkdocs/en/model.md @@ -121,6 +121,27 @@ model = OpenAIModel( ) ``` +#### Advanced Usage + +Since version `1.1.10`, `OpenAIModel` supports passing a shared HTTP client provider to enable connection reuse. By default, `OpenAIModel` creates a temporary HTTP client for each model-service request. If you want to reuse connections, use the following configuration: + +```python +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.models import shared_http_client_provider_factory +# from trpc_agent_sdk.models import close_shared_http_clients + +model = OpenAIModel( + model_name="deepseek-chat", + api_key="your-api-key", + base_url="https://api.deepseek.com/v1", + http_client_provider_factory=shared_http_client_provider_factory, +) +# ... + +# If you need to force-close the shared HTTP clients, use: +# await close_shared_http_clients() +``` + ### Integration with Various Platform Model Services: #### Hunyuan Model Invocation @@ -204,6 +225,10 @@ LlmAgent( ) ``` +#### Advanced Usage + +Since version `1.1.10`, `AnthropicModel` supports passing a shared HTTP client provider to enable connection reuse. By default, `AnthropicModel` creates a temporary HTTP client for each model-service request. See the Advanced Usage section under `OpenAIModel` for an example. + ## LiteLLMModel As multiple LLM providers have emerged, some have defined their own API specifications. Currently, the framework has integrated OpenAI and Anthropic APIs as described above. However, differences in instantiation methods and configuration options across providers mean that developers often need to modify substantial amounts of code when switching providers, increasing the switching cost. To address this issue, tRPC-Agent supports unified multi-provider model access through [LiteLLM](https://docs.litellm.ai/), using the **provider/model** format (e.g., `openai/gpt-4o`, `anthropic/claude-3-5-sonnet`, `gemini/gemini-1.5-pro`), enabling switching between different backends with a single invocation pattern. LiteLLMModel inherits from OpenAIModel and only overrides the API call path to `litellm.acompletion`, simplifying the complexity of provider switching. diff --git a/docs/mkdocs/zh/model.md b/docs/mkdocs/zh/model.md index f8e9944b..16e2320b 100644 --- a/docs/mkdocs/zh/model.md +++ b/docs/mkdocs/zh/model.md @@ -121,6 +121,27 @@ model = OpenAIModel( ) ``` +#### 高级用法 + +从版本 `1.1.10`之后 OpenAIModel 支持传入共享的 http client 来解决连接复用的场景,当前的 OpenAIModel 默认每次都会创建临时的 http client 去访问模型服务;如果期望连接复用可以使用如下的方式 + +```python +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.models import shared_http_client_provider_factory +#from trpc_agent_sdk.models import close_shared_http_clients + +model = OpenAIModel( + model_name="deepseek-chat", + api_key="your-api-key", + base_url="https://api.deepseek.com/v1", + http_client_provider_factory=shared_http_client_provider_factory +) +# ... + +# 如果需要强制关闭共享的 http client 可以采用如下方式 +# await close_shared_http_clients() +``` + ### 各个平台模型服务的对接方式: #### hunyuan模型调用方式 @@ -204,6 +225,10 @@ LlmAgent( ) ``` +#### 高级用法 + +从版本 `1.1.10`之后 AnthropicModel 支持传入共享的 http client 来解决连接复用的场景,当前的 OpenAIModel 默认每次都会创建临时的 http client 去访问模型服务;参考:OpenAIModel 章节中的高级用法 + ## LiteLLMModel 随着多个大模型供应商的出现,一些供应商定义了各自的 API 规范。目前,框架已接入 OpenAI 和 Anthropic 的 API(如上文所述),然而,不同供应商在实例化方式和配置项上存在差异,开发者在切换供应商时往往需要修改大量代码,增加了切换成本。 为了解决这一问题,tRPC-Agent 支持通过 [LiteLLM](https://docs.litellm.ai/) 统一接入多厂商模型,使用 **provider/model** 格式(如 `openai/gpt-4o`、`anthropic/claude-3-5-sonnet`、`gemini/gemini-1.5-pro`),一套调用方式切换不同后端。LiteLLMModel 继承 OpenAIModel,仅覆盖 API 调用路径为 `litellm.acompletion`,从而简化了供应商切换的复杂度。 diff --git a/tests/models/test_anthropic_model.py b/tests/models/test_anthropic_model.py index 8b05db11..6ad1210d 100644 --- a/tests/models/test_anthropic_model.py +++ b/tests/models/test_anthropic_model.py @@ -710,13 +710,13 @@ async def test_generate_single_error_raises_and_closes_client(self): model = AnthropicModel(model_name="claude-3-5-sonnet-20241022", api_key="test-key") client = MagicMock() client.messages.create = AsyncMock(side_effect=TimeoutError("timeout")) - client.close = AsyncMock() + model._http_client_provider.close_http_client = AsyncMock() with patch.object(model, "_create_async_client", return_value=client): with pytest.raises(TimeoutError): await model._generate_single({}, LlmRequest(contents=[])) - client.close.assert_awaited_once() + model._http_client_provider.close_http_client.assert_awaited_once_with(client) @pytest.mark.asyncio async def test_generate_async_converts_provider_exception_to_retry_error_response(self): diff --git a/tests/models/test_anthropic_model_ext.py b/tests/models/test_anthropic_model_ext.py index 8b7e609e..2f02b0d9 100644 --- a/tests/models/test_anthropic_model_ext.py +++ b/tests/models/test_anthropic_model_ext.py @@ -472,11 +472,11 @@ async def test_api_error_raises_and_closes_client(self): model = _model() mock_client = AsyncMock() mock_client.messages.create = AsyncMock(side_effect=RuntimeError("timeout")) - mock_client.close = AsyncMock() + model._http_client_provider.close_http_client = AsyncMock() with patch.object(model, "_create_async_client", return_value=mock_client): with pytest.raises(RuntimeError, match="timeout"): await model._generate_single({}, LlmRequest(contents=[])) - mock_client.close.assert_awaited_once() + model._http_client_provider.close_http_client.assert_awaited_once_with(mock_client) # --------------------------------------------------------------------------- diff --git a/tests/models/test_openai_model.py b/tests/models/test_openai_model.py index 5a05043a..2d4efd9a 100644 --- a/tests/models/test_openai_model.py +++ b/tests/models/test_openai_model.py @@ -8,8 +8,6 @@ from unittest.mock import Mock from unittest.mock import patch -import httpx -import openai import pytest from trpc_agent_sdk.models import LlmRequest from trpc_agent_sdk.models import OpenAIModel @@ -242,23 +240,26 @@ def test_model_type_is_model(self): assert model._type == FilterType.MODEL - def test_create_async_client_uses_custom_http_client_factory(self): - """A custom http_client_factory is passed through to AsyncOpenAI.""" + def test_create_async_client_uses_custom_http_client_provider(self): + """A custom http_client_provider_factory is passed through to AsyncOpenAI.""" shared_http_client = Mock() - http_client_factory = Mock(return_value=shared_http_client) + http_client_provider = Mock() + http_client_provider.create_http_client.return_value = shared_http_client + http_client_provider_factory = Mock(return_value=http_client_provider) model = OpenAIModel( model_name="gpt-4", api_key="test_key", base_url="https://custom.api.com", client_args={"timeout": 30}, - http_client_factory=http_client_factory, + http_client_provider_factory=http_client_provider_factory, ) with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai: client = model._create_async_client() assert client is mock_async_openai.return_value - http_client_factory.assert_called_once_with() + http_client_provider_factory.assert_called_once_with() + http_client_provider.create_http_client.assert_called_once_with() mock_async_openai.assert_called_once_with( api_key="test_key", max_retries=0, @@ -268,39 +269,130 @@ def test_create_async_client_uses_custom_http_client_factory(self): http_client=shared_http_client, ) - def test_create_async_client_default_factory_reuses_shared_http_client(self): - """Default factory should reuse one shared httpx.AsyncClient across model calls.""" - from trpc_agent_sdk.models import _openai_model + def test_create_async_client_default_factory_reuses_loop_local_http_client(self): + """Default provider should reuse the same httpx.AsyncClient within one loop.""" + from trpc_agent_sdk.models import _httpx_client + from trpc_agent_sdk.models import shared_http_client_provider_factory - _openai_model._shared_http_client = None shared_http_client = Mock() - model = OpenAIModel(model_name="gpt-4", api_key="test_key") + shared_http_client.is_closed = False + model = OpenAIModel(model_name="gpt-4", api_key="test_key", http_client_provider_factory=shared_http_client_provider_factory) try: - with patch("trpc_agent_sdk.models._openai_model.httpx.AsyncClient", + _httpx_client._shared_http_clients.clear() + with patch("trpc_agent_sdk.models._httpx_client.httpx.AsyncClient", return_value=shared_http_client) as mock_httpx_client: - with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai: - model._create_async_client() - model._create_async_client() + with patch("trpc_agent_sdk.models._httpx_client._get_loop_key", return_value=1): + with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai: + model._create_async_client() + model._create_async_client() finally: - _openai_model._shared_http_client = None + _httpx_client._shared_http_clients.clear() - mock_httpx_client.assert_called_once_with() + mock_httpx_client.assert_called_once_with( + limits=_httpx_client._DEFAULT_HTTP_CLIENT_LIMITS, + timeout=_httpx_client._DEFAULT_HTTP_CLIENT_TIMEOUT, + follow_redirects=True, + ) first_call_kwargs = mock_async_openai.call_args_list[0].kwargs second_call_kwargs = mock_async_openai.call_args_list[1].kwargs assert first_call_kwargs["http_client"] is shared_http_client assert second_call_kwargs["http_client"] is shared_http_client + def test_create_shared_http_client_rebuilds_closed_client(self): + """Closed cached clients should be replaced on the next factory call.""" + from trpc_agent_sdk.models import _httpx_client + + closed_client = Mock() + closed_client.is_closed = True + fresh_client = Mock() + fresh_client.is_closed = False + + try: + _httpx_client._shared_http_clients.clear() + client_key = (1234, 1) + _httpx_client._shared_http_clients[client_key] = closed_client + with patch("trpc_agent_sdk.models._httpx_client._get_client_key", return_value=client_key): + with patch("trpc_agent_sdk.models._httpx_client.httpx.AsyncClient", + return_value=fresh_client) as mock_httpx_client: + assert _httpx_client._create_shared_http_client() is fresh_client + finally: + _httpx_client._shared_http_clients.clear() + + mock_httpx_client.assert_called_once() + + def test_create_shared_http_client_does_not_reuse_across_loop_keys(self): + """Different event loops should get different default httpx clients.""" + from trpc_agent_sdk.models import _httpx_client + + first_client = Mock() + first_client.is_closed = False + second_client = Mock() + second_client.is_closed = False + + try: + _httpx_client._shared_http_clients.clear() + with patch("trpc_agent_sdk.models._httpx_client.httpx.AsyncClient", + side_effect=[first_client, second_client]) as mock_httpx_client: + with patch("trpc_agent_sdk.models._httpx_client._get_client_key", side_effect=[(1234, 1), (1234, 2)]): + assert _httpx_client._create_shared_http_client() is first_client + assert _httpx_client._create_shared_http_client() is second_client + finally: + _httpx_client._shared_http_clients.clear() + + assert mock_httpx_client.call_count == 2 + + def test_create_shared_http_client_does_not_reuse_across_process_keys(self): + """Different process keys should get different default httpx clients.""" + from trpc_agent_sdk.models import _httpx_client + + parent_client = Mock() + parent_client.is_closed = False + child_client = Mock() + child_client.is_closed = False + + try: + _httpx_client._shared_http_clients.clear() + with patch("trpc_agent_sdk.models._httpx_client.httpx.AsyncClient", + side_effect=[parent_client, child_client]) as mock_httpx_client: + with patch("trpc_agent_sdk.models._httpx_client._get_client_key", + side_effect=[(1234, 1), (5678, 1)]): + assert _httpx_client._create_shared_http_client() is parent_client + assert _httpx_client._create_shared_http_client() is child_client + finally: + _httpx_client._shared_http_clients.clear() + + assert mock_httpx_client.call_count == 2 + + def test_reset_shared_http_clients_after_fork_clears_cache_and_rebuilds_lock(self): + """Fork child reset should drop inherited clients and replace inherited locks.""" + from trpc_agent_sdk.models import _httpx_client + + inherited_client = Mock() + old_lock = _httpx_client._shared_http_clients_lock + + try: + _httpx_client._shared_http_clients[(1234, 1)] = inherited_client + + _httpx_client._reset_shared_http_clients_after_fork() + + assert _httpx_client._shared_http_clients == {} + assert _httpx_client._shared_http_clients_lock is not old_lock + finally: + _httpx_client._shared_http_clients.clear() + def test_create_async_client_overwrites_stale_client_args_http_client(self): - """Factory owns http_client injection even if client_args already has one.""" + """Provider owns http_client injection even if client_args already has one.""" stale_http_client = Mock() fresh_http_client = Mock() - http_client_factory = Mock(return_value=fresh_http_client) + http_client_provider = Mock() + http_client_provider.create_http_client.return_value = fresh_http_client + http_client_provider_factory = Mock(return_value=http_client_provider) model = OpenAIModel( model_name="gpt-4", api_key="test_key", client_args={"http_client": stale_http_client, "timeout": 30}, - http_client_factory=http_client_factory, + http_client_provider_factory=http_client_provider_factory, ) with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai: @@ -354,15 +446,19 @@ async def test_generate_async_simple_text_response(self): @pytest.mark.asyncio async def test_generate_async_validation_failure(self): - """Test generate_async converts validation failures to error responses.""" + """Test generate_async returns an error response on invalid request.""" model = OpenAIModel(model_name="gpt-4", api_key="test_key") + # Empty contents request = LlmRequest(contents=[], config=None, tools_dict={}) - responses = [response async for response in model.generate_async(request, stream=False)] + responses = [] + async for response in model.generate_async(request, stream=False): + responses.append(response) + assert len(responses) == 1 assert responses[0].error_code == "API_ERROR" - assert "At least one content is required" in (responses[0].error_message or "") + assert "At least one content is required" in responses[0].error_message @pytest.mark.asyncio async def test_generate_async_with_config_parameters(self): @@ -849,47 +945,3 @@ def test_null_prompt_tokens_details_does_not_crash(self): } meta = OpenAIModel._build_usage_metadata(usage_data) assert meta.cache_read_input_tokens is None - - -class _RetryTestError(Exception): - - def __init__(self, status_code=None, headers=None): - super().__init__(f"status {status_code}" if status_code is not None else "retry test") - if status_code is not None: - self.status_code = status_code - if headers is not None: - self.response = type("Resp", (), {"headers": headers})() - - -class TestOpenAIModelRetryHooks: - - def _model(self): - return OpenAIModel(model_name="gpt-4", api_key="test_key") - - def test_x_should_retry_header_has_priority(self): - model = self._model() - assert model._get_model_retry_info(_RetryTestError(400, {"x-should-retry": "true"})).should_retry is True - assert model._get_model_retry_info(_RetryTestError(500, {"x-should-retry": "false"})).should_retry is False - - @pytest.mark.parametrize("status_code", [408, 409, 429, 500, 503]) - def test_retryable_status_codes(self, status_code): - assert self._model()._get_model_retry_info(_RetryTestError(status_code)).should_retry is True - - @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 499]) - def test_non_retryable_status_codes(self, status_code): - assert self._model()._get_model_retry_info(_RetryTestError(status_code)).should_retry is False - - def test_timeout_exception_retried(self): - request = httpx.Request("GET", "https://example.com") - assert self._model()._get_model_retry_info(httpx.TimeoutException("timeout", request=request)).should_retry is True - - def test_openai_error_not_retried_without_status_decision(self): - assert self._model()._get_model_retry_info(openai.OpenAIError("boom")).should_retry is False - - def test_other_exception_retried(self): - assert self._model()._get_model_retry_info(ValueError("boom")).should_retry is True - - def test_retry_after_extracted_from_headers(self): - info = self._model()._get_model_retry_info(_RetryTestError(429, {"retry-after": "3"})) - assert info.should_retry is True - assert info.retry_after == 3.0 diff --git a/trpc_agent_sdk/agents/_base_agent.py b/trpc_agent_sdk/agents/_base_agent.py index 48b1eb9a..90ecef95 100644 --- a/trpc_agent_sdk/agents/_base_agent.py +++ b/trpc_agent_sdk/agents/_base_agent.py @@ -257,8 +257,8 @@ async def run_async( - Actions """ from trpc_agent_sdk.telemetry import report_invoke_agent - from trpc_agent_sdk.telemetry._trace import tracer - from trpc_agent_sdk.telemetry._trace import trace_agent + from trpc_agent_sdk.telemetry import tracer + from trpc_agent_sdk.telemetry import trace_agent # Manually propagate span context using attach/detach instead of # start_as_current_span. This ensures child spans (call_llm, execute_tool, diff --git a/trpc_agent_sdk/models/__init__.py b/trpc_agent_sdk/models/__init__.py index 87ea42f1..fc9c8966 100644 --- a/trpc_agent_sdk/models/__init__.py +++ b/trpc_agent_sdk/models/__init__.py @@ -45,6 +45,10 @@ from ._openai_model import OpenAIModel from ._openai_model import ToolCall from ._openai_model import ToolKey +from ._httpx_client import close_shared_http_clients +from ._httpx_client import temporary_http_client_provider_factory +from ._httpx_client import shared_http_client_provider_factory +from ._httpx_client import HttpClientProviderFactory from ._registry import ModelRegistry from ._registry import register_model @@ -85,6 +89,10 @@ "OpenAIModel", "ToolCall", "ToolKey", + "close_shared_http_clients", + "temporary_http_client_provider_factory", + "shared_http_client_provider_factory", + "HttpClientProviderFactory", "ModelRegistry", "register_model", ] diff --git a/trpc_agent_sdk/models/_anthropic_model.py b/trpc_agent_sdk/models/_anthropic_model.py index 85b84cfb..ed13d416 100644 --- a/trpc_agent_sdk/models/_anthropic_model.py +++ b/trpc_agent_sdk/models/_anthropic_model.py @@ -38,6 +38,9 @@ from ._llm_model import LLMModel from ._llm_request import LlmRequest from ._llm_response import LlmResponse +from ._httpx_client import BaseHttpClientProvider +from ._httpx_client import HttpClientProviderFactory +from ._httpx_client import shared_http_client_provider_factory from ._registry import register_model _EPHEMERAL = "ephemeral" @@ -154,29 +157,36 @@ def __init__( model_name: str, filters_name: Optional[list[str]] = None, generate_content_config: Optional[GenerateContentConfig] = None, + http_client_provider_factory: HttpClientProviderFactory = shared_http_client_provider_factory, **kwargs, ): super().__init__(model_name, filters_name, **kwargs) # Extract Anthropic-specific config self.client_args = kwargs.get(const.CLIENT_ARGS, {}) - + # Allow callers to inject a tuned httpx client so the underlying openai.AsyncOpenAI honors connection-pool + # settings such as keepalive_expiry / max_keepalive_connections (avoids reusing stale + # keep-alive sockets that gateways close earlier than httpx). + http_client_provider_factory = http_client_provider_factory or shared_http_client_provider_factory + self._http_client_provider: BaseHttpClientProvider = http_client_provider_factory() # Default generation config that can be overridden per request self.generate_content_config = generate_content_config - def _create_async_client(self): + def _create_async_client(self) -> AsyncAnthropic: """Create a new async client instance.""" # Disable httpx logging to prevent HTTP request logs import logging logging.getLogger("httpx").setLevel(logging.WARNING) + client_args = self.client_args.copy() + client_args['http_client'] = self._http_client_provider.create_http_client() return AsyncAnthropic( api_key=self._api_key, max_retries=0, # disable retries base_url=self._base_url if self._base_url else None, - **self.client_args, + **client_args, ) def is_retriable_status_code(self, status_code: int) -> Optional[bool]: @@ -555,7 +565,7 @@ async def _generate_single( response = await client.messages.create(**api_params) return self._message_to_llm_response(response) finally: - await client.close() + await self._http_client_provider.close_http_client(client) async def _generate_stream( self, @@ -683,7 +693,7 @@ async def _generate_stream( logger.error("Error in streaming response", exc_info=True) raise finally: - await client.close() + await self._http_client_provider.close_http_client(client) def _apply_prompt_cache(self, api_params: Dict[str, Any], ctx: InvocationContext | None) -> None: """Inject Anthropic native cache_control breakpoints (opt-in, no-op when disabled).""" diff --git a/trpc_agent_sdk/models/_httpx_client.py b/trpc_agent_sdk/models/_httpx_client.py new file mode 100644 index 00000000..9db7e2d1 --- /dev/null +++ b/trpc_agent_sdk/models/_httpx_client.py @@ -0,0 +1,144 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""HTTPX client implementation module. + +This module provides the HTTPX client implementation for TRPC Agent framework. +""" + +import httpx +import asyncio +import threading +import inspect +import os +from abc import ABC +from abc import abstractmethod +from typing import Callable +from typing import Optional +from typing import Any +from typing_extensions import override + +_DEFAULT_HTTP_CLIENT_LIMITS = httpx.Limits( + max_connections=1000, + max_keepalive_connections=100, + keepalive_expiry=30.0, +) +_DEFAULT_HTTP_CLIENT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=5.0) + + +class BaseHttpClientProvider(ABC): + """Provider for HTTP clients.""" + + @abstractmethod + def create_http_client(self) -> httpx.AsyncClient: + """Create an HTTP client.""" + raise NotImplementedError("Subclasses must implement this method") + + @abstractmethod + async def close_http_client(self, client: Any) -> None: + """Close an HTTP client.""" + raise NotImplementedError("Subclasses must implement this method") + + +class TemporaryHttpClientProvider(BaseHttpClientProvider): + """Provider for temporary HTTP clients.""" + + @override + def create_http_client(self) -> Optional[httpx.AsyncClient]: + """Create a temporary HTTP client.""" + return None + + @override + async def close_http_client(self, client: Any) -> None: + """Close a temporary HTTP client.""" + close_method = getattr(client, "close", None) + if callable(close_method): + result = close_method() + if inspect.isawaitable(result): + await result + + +_shared_http_clients: dict[tuple[int, int], httpx.AsyncClient] = {} +_shared_http_clients_lock: threading.RLock = threading.RLock() + + +def _get_loop_key() -> int: + """Return a cache key for the current event loop, or a process-local fallback.""" + try: + return id(asyncio.get_running_loop()) + except RuntimeError: + return 0 + + +def _get_client_key() -> tuple[int, int]: + """Return a process-local and loop-local cache key for shared HTTP clients.""" + return os.getpid(), _get_loop_key() + + +def _reset_shared_http_clients_after_fork() -> None: + """Drop inherited clients and recreate the lock in a forked child process.""" + global _shared_http_clients_lock + _shared_http_clients.clear() + _shared_http_clients_lock = threading.RLock() + + +if hasattr(os, "register_at_fork"): + os.register_at_fork(after_in_child=_reset_shared_http_clients_after_fork) + + +def _create_shared_http_client() -> httpx.AsyncClient: + """Return a loop-local shared HTTP client with bounded keep-alive reuse. + + Returns: + A loop-local shared HTTP client with bounded keep-alive reuse. + """ + client_key = _get_client_key() + with _shared_http_clients_lock: + client = _shared_http_clients.get(client_key) + if client is None or client.is_closed: + client = httpx.AsyncClient( + limits=_DEFAULT_HTTP_CLIENT_LIMITS, + timeout=_DEFAULT_HTTP_CLIENT_TIMEOUT, + follow_redirects=True, + ) + _shared_http_clients[client_key] = client + return client + + +class SharedHttpClientProvider(BaseHttpClientProvider): + """Provider for shared HTTP clients.""" + + @override + def create_http_client(self) -> Optional[httpx.AsyncClient]: + """Create a shared HTTP client.""" + return _create_shared_http_client() + + @override + async def close_http_client(self, client: Any) -> None: + """Close a shared HTTP client.""" + return None + + +HttpClientProviderFactory = Callable[[], BaseHttpClientProvider] + + +def temporary_http_client_provider_factory() -> BaseHttpClientProvider: + """Provider for temporary HTTP clients.""" + return TemporaryHttpClientProvider() + + +def shared_http_client_provider_factory() -> BaseHttpClientProvider: + """Provider for shared HTTP clients.""" + return SharedHttpClientProvider() + + +async def close_shared_http_clients() -> None: + """Close HTTP clients created by the default HTTP client factory.""" + with _shared_http_clients_lock: + clients = list(_shared_http_clients.values()) + _shared_http_clients.clear() + for client in clients: + if not client.is_closed: + await client.aclose() diff --git a/trpc_agent_sdk/models/_openai_model.py b/trpc_agent_sdk/models/_openai_model.py index 1e33f8e2..d07ed688 100644 --- a/trpc_agent_sdk/models/_openai_model.py +++ b/trpc_agent_sdk/models/_openai_model.py @@ -19,7 +19,6 @@ from typing import Dict from typing import List from typing import Optional -from typing import Callable from typing_extensions import override import httpx @@ -43,6 +42,9 @@ from ._llm_request import LlmRequest from ._llm_response import LlmResponse from ._registry import register_model +from ._httpx_client import BaseHttpClientProvider +from ._httpx_client import HttpClientProviderFactory +from ._httpx_client import temporary_http_client_provider_factory from .openai_adapter import get_openai_adapter from .tool_prompt import ToolPromptFactory from .tool_prompt import get_factory @@ -108,19 +110,6 @@ class ApiParamsKey(str, Enum): PROMPT_CACHE_RETENTION = "prompt_cache_retention" -HttpClientFactory = Callable[[], httpx.AsyncClient] - -_shared_http_client: httpx.AsyncClient | None = None - - -def default_http_client_factory() -> httpx.AsyncClient: - """Create a default HTTP client.""" - global _shared_http_client - if _shared_http_client is None: - _shared_http_client = httpx.AsyncClient() - return _shared_http_client - - @register_model(model_name="OpenAIModel", supported_models=[r"gpt-.*", r"o1-.*", r"deepseek-.*", r"hy3-.*"]) class OpenAIModel(LLMModel): """OpenAI model implementation using the abstract model interface. @@ -178,6 +167,7 @@ def __init__( add_tools_to_prompt: bool = False, tool_prompt: str = "xml", generate_content_config: Optional[GenerateContentConfig] = None, + http_client_provider_factory: HttpClientProviderFactory = temporary_http_client_provider_factory, **kwargs, ): super().__init__(model_name, filters_name, **kwargs) @@ -186,11 +176,9 @@ def __init__( # Extract OpenAI-specific config self.organization: str = kwargs.get(const.ORGANIZATION, "") self.client_args = kwargs.get(const.CLIENT_ARGS, {}) - # Allow callers to inject a tuned httpx client so the underlying openai.AsyncOpenAI honors connection-pool - # settings such as keepalive_expiry / max_keepalive_connections (avoids reusing stale - # keep-alive sockets that gateways close earlier than httpx). - self._http_client_factory: HttpClientFactory = kwargs.pop("http_client_factory", None) - self._http_client_factory = self._http_client_factory or default_http_client_factory + # Allow callers to inject a tuned httpx client + http_client_provider_factory = http_client_provider_factory or temporary_http_client_provider_factory + self._http_client_provider: BaseHttpClientProvider = http_client_provider_factory() # Tool prompt configuration self.add_tools_to_prompt = add_tools_to_prompt @@ -246,7 +234,7 @@ def _create_async_client(self) -> openai.AsyncOpenAI: logging.getLogger("httpx").setLevel(logging.WARNING) client_args = self.client_args.copy() - client_args['http_client'] = self._http_client_factory() + client_args['http_client'] = self._http_client_provider.create_http_client() return openai.AsyncOpenAI( api_key=self._api_key, @@ -1123,23 +1111,25 @@ async def _generate_single(self, if http_options is None: http_options = {} client = self._create_async_client() - response = await client.chat.completions.create(**api_params, **http_options) - response_dict: dict = response.model_dump() + try: + response = await client.chat.completions.create(**api_params, **http_options) + response_dict: dict = response.model_dump() - # Check if the response contains valid text content or tool calls - has_text_content = self._verify_text_content_in_openai_message_response(response_dict) - has_tool_calls = False + # Check if the response contains valid text content or tool calls + has_text_content = self._verify_text_content_in_openai_message_response(response_dict) + has_tool_calls = False - # Check for tool calls - choices: list[dict] = response_dict.get(const.CHOICES, [{}]) - if choices and choices[0].get(const.MESSAGE, {}).get(const.TOOL_CALLS): - has_tool_calls = True + # Check for tool calls + choices: list[dict] = response_dict.get(const.CHOICES, [{}]) + if choices and choices[0].get(const.MESSAGE, {}).get(const.TOOL_CALLS): + has_tool_calls = True - # Create response with content if we have text or tool calls - if has_text_content or has_tool_calls: - return self._create_response_with_content(response_dict) - else: + # Create response with content if we have text or tool calls + if has_text_content or has_tool_calls: + return self._create_response_with_content(response_dict) return self._create_response_without_content(response_dict) + finally: + await self._http_client_provider.close_http_client(client) def _convert_tools_to_openai_format(self, tools: List[Tool]) -> List[Dict[str, Any]]: """Convert Google GenAI tools format to OpenAI tools format. @@ -1580,199 +1570,207 @@ async def _generate_stream(self, client = self._create_async_client() logger.debug("openai invoke with params: %s", api_params) - response = await client.chat.completions.create(**api_params, **http_options) - if response is None: - raise ValueError("Empty response from API") + try: + response = await client.chat.completions.create(**api_params, **http_options) + if response is None: + raise ValueError("Empty response from API") - async for chunk in response: - if chunk is None: - continue + async for chunk in response: + if chunk is None: + continue - chunk_dict: dict = chunk.model_dump() - logger.debug("🔥 RAW LLM CHUNK: %s", json.dumps(chunk_dict, ensure_ascii=False)) - - # Capture response ID from chunk (only set once from first chunk that has it) - if response_id is None and chunk_dict.get("id"): - response_id = chunk_dict.get("id") - - # Check for thinking events first - if self._is_thinking_event(chunk_dict): - thinking_state = self._get_thinking_state(chunk_dict) - if thinking_state == 0: - # Start thinking - is_thinking = True - elif thinking_state == 2: - # End thinking - is_thinking = False - # Handle thinking events - these are metadata about thinking state - # We can log them but don't need to yield them as content - continue + chunk_dict: dict = chunk.model_dump() + logger.debug("🔥 RAW LLM CHUNK: %s", json.dumps(chunk_dict, ensure_ascii=False)) + + # Capture response ID from chunk (only set once from first chunk that has it) + if response_id is None and chunk_dict.get("id"): + response_id = chunk_dict.get("id") + + # Check for thinking events first + if self._is_thinking_event(chunk_dict): + thinking_state = self._get_thinking_state(chunk_dict) + if thinking_state == 0: + # Start thinking + is_thinking = True + elif thinking_state == 2: + # End thinking + is_thinking = False + # Handle thinking events - these are metadata about thinking state + # We can log them but don't need to yield them as content + continue - # Verify if the chunk contains valid text content - if not self._verify_text_content_in_delta_response(chunk_dict): - # Process chunk without content (this is where tool calls are streamed) - _, usage, delta_arguments = self._process_chunk_without_content(chunk_dict, accumulated_tool_calls) + # Verify if the chunk contains valid text content + if not self._verify_text_content_in_delta_response(chunk_dict): + # Process chunk without content (this is where tool calls are streamed) + _, usage, delta_arguments = self._process_chunk_without_content(chunk_dict, accumulated_tool_calls) + + if usage: + last_usage = usage + + # If streaming tool call arguments is enabled, yield partial tool call events + # delta_arguments being non-empty means tool calls were processed + if streaming_tool_names and delta_arguments and accumulated_tool_calls: + # Yield streaming tool call event with delta arguments + # Only for tools in streaming_tool_names + streaming_event = self._create_streaming_tool_call_response(accumulated_tool_calls, + delta_arguments, + streaming_tool_names) + if streaming_event: + yield streaming_event - if usage: - last_usage = usage - - # If streaming tool call arguments is enabled, yield partial tool call events - # delta_arguments being non-empty means tool calls were processed - if streaming_tool_names and delta_arguments and accumulated_tool_calls: - # Yield streaming tool call event with delta arguments - # Only for tools in streaming_tool_names - streaming_event = self._create_streaming_tool_call_response(accumulated_tool_calls, delta_arguments, - streaming_tool_names) - if streaming_event: - yield streaming_event + continue - continue + # Process chunk with valid content + choices = chunk_dict.get(const.CHOICES, []) + if not choices: + continue # Skip if no choices available + choice: dict[str, dict] = choices[0] + delta = choice[const.DELTA] + + # Handle reasoning content (thinking content) first + if delta.get(const.REASONING_CONTENT): + reasoning_content = delta.get(const.REASONING_CONTENT) + if reasoning_content is not None: + partial_text = reasoning_content + if (tool_prompt and streaming_text_filter_state is not None + and self._adapter.should_filter_reasoning_text()): + reasoning_filter_state = streaming_text_filter_state["reasoning"] + partial_text = self._adapter.filter_streaming_text(reasoning_content, + reasoning_filter_state) + if not partial_text: + continue - # Process chunk with valid content - choices = chunk_dict.get(const.CHOICES, []) - if not choices: - continue # Skip if no choices available - choice: dict[str, dict] = choices[0] - delta = choice[const.DELTA] - - # Handle reasoning content (thinking content) first - if delta.get(const.REASONING_CONTENT): - reasoning_content = delta.get(const.REASONING_CONTENT) - if reasoning_content is not None: - partial_text = reasoning_content - if (tool_prompt and streaming_text_filter_state is not None - and self._adapter.should_filter_reasoning_text()): - reasoning_filter_state = streaming_text_filter_state["reasoning"] - partial_text = self._adapter.filter_streaming_text(reasoning_content, reasoning_filter_state) - if not partial_text: - continue - - # Reasoning content is always thinking content - thought_content += partial_text - - # Set thought flag to True for reasoning content - content_part = Part.from_text(text=partial_text) - content_part.thought = True + # Reasoning content is always thinking content + thought_content += partial_text + + # Set thought flag to True for reasoning content + content_part = Part.from_text(text=partial_text) + content_part.thought = True + + partial_content = Content(parts=[content_part], role=const.MODEL) + yield LlmResponse(content=partial_content, + partial=True, + response_id=response_id, + custom_metadata={const.CHUNK: chunk_dict}) + + # Handle regular content + if delta.get(const.CONTENT): + content = delta.get(const.CONTENT) + if content is not None: + if not is_thinking: + accumulated_content += content + else: + thought_content += content - partial_content = Content(parts=[content_part], role=const.MODEL) - yield LlmResponse(content=partial_content, - partial=True, - response_id=response_id, - custom_metadata={const.CHUNK: chunk_dict}) - - # Handle regular content - if delta.get(const.CONTENT): - content = delta.get(const.CONTENT) - if content is not None: - if not is_thinking: - accumulated_content += content - else: - thought_content += content + partial_text = content + if tool_prompt and streaming_text_filter_state is not None: + content_filter_state = streaming_text_filter_state["content"] + partial_text = self._adapter.filter_streaming_text(content, content_filter_state) + if not partial_text: + continue - partial_text = content - if tool_prompt and streaming_text_filter_state is not None: - content_filter_state = streaming_text_filter_state["content"] - partial_text = self._adapter.filter_streaming_text(content, content_filter_state) - if not partial_text: - continue + # Set thought flag based on current thinking state + content_part = Part.from_text(text=partial_text) + content_part.thought = is_thinking - # Set thought flag based on current thinking state - content_part = Part.from_text(text=partial_text) - content_part.thought = is_thinking + partial_content = Content(parts=[content_part], role=const.MODEL) + yield LlmResponse(content=partial_content, + partial=True, + response_id=response_id, + custom_metadata={const.CHUNK: chunk_dict}) - partial_content = Content(parts=[content_part], role=const.MODEL) - yield LlmResponse(content=partial_content, - partial=True, - response_id=response_id, - custom_metadata={const.CHUNK: chunk_dict}) + # Handle usage + usage = self._process_usage(chunk_dict) + if usage: + last_usage = usage - # Handle usage - usage = self._process_usage(chunk_dict) - if usage: - last_usage = usage - - if tool_prompt and streaming_text_filter_state is not None: - if self._adapter.should_filter_reasoning_text(): - flushed_reasoning_text = self._adapter.flush_streaming_text(streaming_text_filter_state["reasoning"]) - if flushed_reasoning_text: - thought_content += flushed_reasoning_text - content_part = Part.from_text(text=flushed_reasoning_text) - content_part.thought = True + if tool_prompt and streaming_text_filter_state is not None: + if self._adapter.should_filter_reasoning_text(): + flushed_reasoning_text = self._adapter.flush_streaming_text( + streaming_text_filter_state["reasoning"]) + if flushed_reasoning_text: + thought_content += flushed_reasoning_text + content_part = Part.from_text(text=flushed_reasoning_text) + content_part.thought = True + partial_content = Content(parts=[content_part], role=const.MODEL) + yield LlmResponse(content=partial_content, + partial=True, + response_id=response_id, + custom_metadata={"stream_filter_flushed": "reasoning"}) + + flushed_content_text = self._adapter.flush_streaming_text(streaming_text_filter_state["content"]) + if flushed_content_text: + content_part = Part.from_text(text=flushed_content_text) + content_part.thought = is_thinking partial_content = Content(parts=[content_part], role=const.MODEL) yield LlmResponse(content=partial_content, partial=True, response_id=response_id, - custom_metadata={"stream_filter_flushed": "reasoning"}) - - flushed_content_text = self._adapter.flush_streaming_text(streaming_text_filter_state["content"]) - if flushed_content_text: - content_part = Part.from_text(text=flushed_content_text) - content_part.thought = is_thinking - partial_content = Content(parts=[content_part], role=const.MODEL) - yield LlmResponse(content=partial_content, - partial=True, - response_id=response_id, - custom_metadata={"stream_filter_flushed": "content"}) + custom_metadata={"stream_filter_flushed": "content"}) - # Yield final complete response - final_content = None - - parts = [] + # Yield final complete response + final_content = None - if thought_content: - logger.debug("Final accumulated thought content: %s...", thought_content[:200]) - content_part = Part.from_text(text=thought_content) - content_part.thought = True - parts.append(content_part) + parts = [] - # Parse function calls from final accumulated content if add_tools_to_prompt is enabled - complete_tool_calls = self._create_complete_tool_calls(accumulated_tool_calls) - if tool_prompt and accumulated_content and not complete_tool_calls: - try: - parsed_function_calls = self._adapter.parse_tool_prompt_function_calls(accumulated_content, tool_prompt) - if parsed_function_calls: - # Convert FunctionCall objects to ToolCall objects - complete_tool_calls = [] - for func_call in parsed_function_calls: - tool_call = ToolCall(id=f"call_{uuid.uuid4().hex[:24]}", - name=func_call.name, - arguments=func_call.args) - complete_tool_calls.append(tool_call) - logger.debug("Parsed %s function calls from final accumulated content", len(complete_tool_calls)) - except Exception as ex: # pylint: disable=broad-except - logger.warning("Failed to parse function calls from final accumulated content: %s", ex) + if thought_content: + logger.debug("Final accumulated thought content: %s...", thought_content[:200]) + content_part = Part.from_text(text=thought_content) + content_part.thought = True + parts.append(content_part) - # Add text content if present - if accumulated_content and not complete_tool_calls: - logger.debug("Final accumulated regular content: %s...", accumulated_content[:200]) - content_part = Part.from_text(text=accumulated_content) - content_part.thought = False # Final accumulated content represents the answer, not thinking - parts.append(content_part) - - if complete_tool_calls: - for tool_call in complete_tool_calls: - # Create Part with function_call using the from_function_call method - function_part = Part.from_function_call(name=tool_call.name, args=tool_call.arguments) - # Set the id if available - if tool_call.id: - function_part.function_call.id = tool_call.id # type: ignore - self._set_part_thought_signature(function_part, tool_call.thought_signature) - parts.append(function_part) - - # Create final content with parts - if parts: - final_content = Content(parts=parts, role=const.MODEL) - - # Convert usage to the expected format for LlmResponse - final_usage = None - if last_usage: - # Create a compatible usage metadata object - final_usage = last_usage # Use the existing usage object for now - - yield LlmResponse( - content=final_content, - usage_metadata=final_usage, - partial=False, - response_id=response_id, - custom_metadata={"stream_complete": True}, - ) + # Parse function calls from final accumulated content if add_tools_to_prompt is enabled + complete_tool_calls = self._create_complete_tool_calls(accumulated_tool_calls) + if tool_prompt and accumulated_content and not complete_tool_calls: + try: + parsed_function_calls = self._adapter.parse_tool_prompt_function_calls( + accumulated_content, tool_prompt) + if parsed_function_calls: + # Convert FunctionCall objects to ToolCall objects + complete_tool_calls = [] + for func_call in parsed_function_calls: + tool_call = ToolCall(id=f"call_{uuid.uuid4().hex[:24]}", + name=func_call.name, + arguments=func_call.args) + complete_tool_calls.append(tool_call) + logger.debug("Parsed %s function calls from final accumulated content", + len(complete_tool_calls)) + except Exception as ex: # pylint: disable=broad-except + logger.warning("Failed to parse function calls from final accumulated content: %s", ex) + + # Add text content if present + if accumulated_content and not complete_tool_calls: + logger.debug("Final accumulated regular content: %s...", accumulated_content[:200]) + content_part = Part.from_text(text=accumulated_content) + content_part.thought = False # Final accumulated content represents the answer, not thinking + parts.append(content_part) + + if complete_tool_calls: + for tool_call in complete_tool_calls: + # Create Part with function_call using the from_function_call method + function_part = Part.from_function_call(name=tool_call.name, args=tool_call.arguments) + # Set the id if available + if tool_call.id: + function_part.function_call.id = tool_call.id # type: ignore + self._set_part_thought_signature(function_part, tool_call.thought_signature) + parts.append(function_part) + + # Create final content with parts + if parts: + final_content = Content(parts=parts, role=const.MODEL) + + # Convert usage to the expected format for LlmResponse + final_usage = None + if last_usage: + # Create a compatible usage metadata object + final_usage = last_usage # Use the existing usage object for now + + yield LlmResponse( + content=final_content, + usage_metadata=final_usage, + partial=False, + response_id=response_id, + custom_metadata={"stream_complete": True}, + ) + finally: + await self._http_client_provider.close_http_client(client) diff --git a/trpc_agent_sdk/runners.py b/trpc_agent_sdk/runners.py index 5d48f26d..ffc40914 100644 --- a/trpc_agent_sdk/runners.py +++ b/trpc_agent_sdk/runners.py @@ -35,8 +35,8 @@ from trpc_agent_sdk.sessions import BaseSessionService from trpc_agent_sdk.sessions import Session from trpc_agent_sdk.telemetry import tracer -from trpc_agent_sdk.telemetry._trace import trace_cancellation -from trpc_agent_sdk.telemetry._trace import trace_runner +from trpc_agent_sdk.telemetry import trace_cancellation +from trpc_agent_sdk.telemetry import trace_runner from trpc_agent_sdk.tools import BaseToolSet from trpc_agent_sdk.types import Content from trpc_agent_sdk.types import Part diff --git a/trpc_agent_sdk/telemetry/_metrics.py b/trpc_agent_sdk/telemetry/_metrics.py index 7e0cc6b2..f7eb462b 100644 --- a/trpc_agent_sdk/telemetry/_metrics.py +++ b/trpc_agent_sdk/telemetry/_metrics.py @@ -5,7 +5,7 @@ # tRPC-Agent-Python is licensed under Apache-2.0. """OTel-native ``gen_ai.*`` metrics for the TRPC Agent framework. -Mirrors :mod:`trpc_agent_sdk.telemetry._trace`: module-level instruments plus +Mirrors :mod:`trpc_agent_sdk.telemetry`: module-level instruments plus ``report_*`` free functions. Backends fan out via the installed ``MeterProvider`` and route by ``gen_ai.operation.name``. """