Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/mkdocs/en/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,27 @@ model = OpenAIModel(
)
```

#### Advanced Usage

Since version `1.1.10`, `OpenAIModel` supports passing a shared HTTP client provider to enable connection reuse. By default, `OpenAIModel` creates a temporary HTTP client for each model-service request. If you want to reuse connections, use the following configuration:

```python
from trpc_agent_sdk.models import OpenAIModel
from trpc_agent_sdk.models import shared_http_client_provider_factory
# from trpc_agent_sdk.models import close_shared_http_clients

model = OpenAIModel(
model_name="deepseek-chat",
api_key="your-api-key",
base_url="https://api.deepseek.com/v1",
http_client_provider_factory=shared_http_client_provider_factory,
)
# ...

# If you need to force-close the shared HTTP clients, use:
# await close_shared_http_clients()
```

### Integration with Various Platform Model Services:

#### Hunyuan Model Invocation
Expand Down Expand Up @@ -204,6 +225,10 @@ LlmAgent(
)
```

#### Advanced Usage

Since version `1.1.10`, `AnthropicModel` supports passing a shared HTTP client provider to enable connection reuse. By default, `AnthropicModel` creates a temporary HTTP client for each model-service request. See the Advanced Usage section under `OpenAIModel` for an example.

## LiteLLMModel
As multiple LLM providers have emerged, some have defined their own API specifications. Currently, the framework has integrated OpenAI and Anthropic APIs as described above. However, differences in instantiation methods and configuration options across providers mean that developers often need to modify substantial amounts of code when switching providers, increasing the switching cost.
To address this issue, tRPC-Agent supports unified multi-provider model access through [LiteLLM](https://docs.litellm.ai/), using the **provider/model** format (e.g., `openai/gpt-4o`, `anthropic/claude-3-5-sonnet`, `gemini/gemini-1.5-pro`), enabling switching between different backends with a single invocation pattern. LiteLLMModel inherits from OpenAIModel and only overrides the API call path to `litellm.acompletion`, simplifying the complexity of provider switching.
Expand Down
25 changes: 25 additions & 0 deletions docs/mkdocs/zh/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,27 @@ model = OpenAIModel(
)
```

#### 高级用法

从版本 `1.1.10`之后 OpenAIModel 支持传入共享的 http client 来解决连接复用的场景,当前的 OpenAIModel 默认每次都会创建临时的 http client 去访问模型服务;如果期望连接复用可以使用如下的方式

```python
from trpc_agent_sdk.models import OpenAIModel
from trpc_agent_sdk.models import shared_http_client_provider_factory
#from trpc_agent_sdk.models import close_shared_http_clients

model = OpenAIModel(
model_name="deepseek-chat",
api_key="your-api-key",
base_url="https://api.deepseek.com/v1",
http_client_provider_factory=shared_http_client_provider_factory
)
# ...

# 如果需要强制关闭共享的 http client 可以采用如下方式
# await close_shared_http_clients()
```

### 各个平台模型服务的对接方式:

#### hunyuan模型调用方式
Expand Down Expand Up @@ -204,6 +225,10 @@ LlmAgent(
)
```

#### 高级用法

从版本 `1.1.10`之后 AnthropicModel 支持传入共享的 http client 来解决连接复用的场景,当前的 OpenAIModel 默认每次都会创建临时的 http client 去访问模型服务;参考:OpenAIModel 章节中的高级用法

## LiteLLMModel
随着多个大模型供应商的出现,一些供应商定义了各自的 API 规范。目前,框架已接入 OpenAI 和 Anthropic 的 API(如上文所述),然而,不同供应商在实例化方式和配置项上存在差异,开发者在切换供应商时往往需要修改大量代码,增加了切换成本。
为了解决这一问题,tRPC-Agent 支持通过 [LiteLLM](https://docs.litellm.ai/) 统一接入多厂商模型,使用 **provider/model** 格式(如 `openai/gpt-4o`、`anthropic/claude-3-5-sonnet`、`gemini/gemini-1.5-pro`),一套调用方式切换不同后端。LiteLLMModel 继承 OpenAIModel,仅覆盖 API 调用路径为 `litellm.acompletion`,从而简化了供应商切换的复杂度。
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_anthropic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,13 +710,13 @@ async def test_generate_single_error_raises_and_closes_client(self):
model = AnthropicModel(model_name="claude-3-5-sonnet-20241022", api_key="test-key")
client = MagicMock()
client.messages.create = AsyncMock(side_effect=TimeoutError("timeout"))
client.close = AsyncMock()
model._http_client_provider.close_http_client = AsyncMock()

with patch.object(model, "_create_async_client", return_value=client):
with pytest.raises(TimeoutError):
await model._generate_single({}, LlmRequest(contents=[]))

client.close.assert_awaited_once()
model._http_client_provider.close_http_client.assert_awaited_once_with(client)

@pytest.mark.asyncio
async def test_generate_async_converts_provider_exception_to_retry_error_response(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_anthropic_model_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,11 @@ async def test_api_error_raises_and_closes_client(self):
model = _model()
mock_client = AsyncMock()
mock_client.messages.create = AsyncMock(side_effect=RuntimeError("timeout"))
mock_client.close = AsyncMock()
model._http_client_provider.close_http_client = AsyncMock()
with patch.object(model, "_create_async_client", return_value=mock_client):
with pytest.raises(RuntimeError, match="timeout"):
await model._generate_single({}, LlmRequest(contents=[]))
mock_client.close.assert_awaited_once()
model._http_client_provider.close_http_client.assert_awaited_once_with(mock_client)


# ---------------------------------------------------------------------------
Expand Down
188 changes: 120 additions & 68 deletions tests/models/test_openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from unittest.mock import Mock
from unittest.mock import patch

import httpx
import openai
import pytest
from trpc_agent_sdk.models import LlmRequest
from trpc_agent_sdk.models import OpenAIModel
Expand Down Expand Up @@ -242,23 +240,26 @@ def test_model_type_is_model(self):

assert model._type == FilterType.MODEL

def test_create_async_client_uses_custom_http_client_factory(self):
"""A custom http_client_factory is passed through to AsyncOpenAI."""
def test_create_async_client_uses_custom_http_client_provider(self):
"""A custom http_client_provider_factory is passed through to AsyncOpenAI."""
shared_http_client = Mock()
http_client_factory = Mock(return_value=shared_http_client)
http_client_provider = Mock()
http_client_provider.create_http_client.return_value = shared_http_client
http_client_provider_factory = Mock(return_value=http_client_provider)
model = OpenAIModel(
model_name="gpt-4",
api_key="test_key",
base_url="https://custom.api.com",
client_args={"timeout": 30},
http_client_factory=http_client_factory,
http_client_provider_factory=http_client_provider_factory,
)

with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai:
client = model._create_async_client()

assert client is mock_async_openai.return_value
http_client_factory.assert_called_once_with()
http_client_provider_factory.assert_called_once_with()
http_client_provider.create_http_client.assert_called_once_with()
mock_async_openai.assert_called_once_with(
api_key="test_key",
max_retries=0,
Expand All @@ -268,39 +269,130 @@ def test_create_async_client_uses_custom_http_client_factory(self):
http_client=shared_http_client,
)

def test_create_async_client_default_factory_reuses_shared_http_client(self):
"""Default factory should reuse one shared httpx.AsyncClient across model calls."""
from trpc_agent_sdk.models import _openai_model
def test_create_async_client_default_factory_reuses_loop_local_http_client(self):
"""Default provider should reuse the same httpx.AsyncClient within one loop."""
from trpc_agent_sdk.models import _httpx_client
from trpc_agent_sdk.models import shared_http_client_provider_factory

_openai_model._shared_http_client = None
shared_http_client = Mock()
model = OpenAIModel(model_name="gpt-4", api_key="test_key")
shared_http_client.is_closed = False
model = OpenAIModel(model_name="gpt-4", api_key="test_key", http_client_provider_factory=shared_http_client_provider_factory)

try:
with patch("trpc_agent_sdk.models._openai_model.httpx.AsyncClient",
_httpx_client._shared_http_clients.clear()
with patch("trpc_agent_sdk.models._httpx_client.httpx.AsyncClient",
return_value=shared_http_client) as mock_httpx_client:
with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai:
model._create_async_client()
model._create_async_client()
with patch("trpc_agent_sdk.models._httpx_client._get_loop_key", return_value=1):
with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai:
model._create_async_client()
model._create_async_client()
finally:
_openai_model._shared_http_client = None
_httpx_client._shared_http_clients.clear()

mock_httpx_client.assert_called_once_with()
mock_httpx_client.assert_called_once_with(
limits=_httpx_client._DEFAULT_HTTP_CLIENT_LIMITS,
timeout=_httpx_client._DEFAULT_HTTP_CLIENT_TIMEOUT,
follow_redirects=True,
)
first_call_kwargs = mock_async_openai.call_args_list[0].kwargs
second_call_kwargs = mock_async_openai.call_args_list[1].kwargs
assert first_call_kwargs["http_client"] is shared_http_client
assert second_call_kwargs["http_client"] is shared_http_client

def test_create_shared_http_client_rebuilds_closed_client(self):
"""Closed cached clients should be replaced on the next factory call."""
from trpc_agent_sdk.models import _httpx_client

closed_client = Mock()
closed_client.is_closed = True
fresh_client = Mock()
fresh_client.is_closed = False

try:
_httpx_client._shared_http_clients.clear()
client_key = (1234, 1)
_httpx_client._shared_http_clients[client_key] = closed_client
with patch("trpc_agent_sdk.models._httpx_client._get_client_key", return_value=client_key):
with patch("trpc_agent_sdk.models._httpx_client.httpx.AsyncClient",
return_value=fresh_client) as mock_httpx_client:
assert _httpx_client._create_shared_http_client() is fresh_client
finally:
_httpx_client._shared_http_clients.clear()

mock_httpx_client.assert_called_once()

def test_create_shared_http_client_does_not_reuse_across_loop_keys(self):
"""Different event loops should get different default httpx clients."""
from trpc_agent_sdk.models import _httpx_client

first_client = Mock()
first_client.is_closed = False
second_client = Mock()
second_client.is_closed = False

try:
_httpx_client._shared_http_clients.clear()
with patch("trpc_agent_sdk.models._httpx_client.httpx.AsyncClient",
side_effect=[first_client, second_client]) as mock_httpx_client:
with patch("trpc_agent_sdk.models._httpx_client._get_client_key", side_effect=[(1234, 1), (1234, 2)]):
assert _httpx_client._create_shared_http_client() is first_client
assert _httpx_client._create_shared_http_client() is second_client
finally:
_httpx_client._shared_http_clients.clear()

assert mock_httpx_client.call_count == 2

def test_create_shared_http_client_does_not_reuse_across_process_keys(self):
"""Different process keys should get different default httpx clients."""
from trpc_agent_sdk.models import _httpx_client

parent_client = Mock()
parent_client.is_closed = False
child_client = Mock()
child_client.is_closed = False

try:
_httpx_client._shared_http_clients.clear()
with patch("trpc_agent_sdk.models._httpx_client.httpx.AsyncClient",
side_effect=[parent_client, child_client]) as mock_httpx_client:
with patch("trpc_agent_sdk.models._httpx_client._get_client_key",
side_effect=[(1234, 1), (5678, 1)]):
assert _httpx_client._create_shared_http_client() is parent_client
assert _httpx_client._create_shared_http_client() is child_client
finally:
_httpx_client._shared_http_clients.clear()

assert mock_httpx_client.call_count == 2

def test_reset_shared_http_clients_after_fork_clears_cache_and_rebuilds_lock(self):
"""Fork child reset should drop inherited clients and replace inherited locks."""
from trpc_agent_sdk.models import _httpx_client

inherited_client = Mock()
old_lock = _httpx_client._shared_http_clients_lock

try:
_httpx_client._shared_http_clients[(1234, 1)] = inherited_client

_httpx_client._reset_shared_http_clients_after_fork()

assert _httpx_client._shared_http_clients == {}
assert _httpx_client._shared_http_clients_lock is not old_lock
finally:
_httpx_client._shared_http_clients.clear()

def test_create_async_client_overwrites_stale_client_args_http_client(self):
"""Factory owns http_client injection even if client_args already has one."""
"""Provider owns http_client injection even if client_args already has one."""
stale_http_client = Mock()
fresh_http_client = Mock()
http_client_factory = Mock(return_value=fresh_http_client)
http_client_provider = Mock()
http_client_provider.create_http_client.return_value = fresh_http_client
http_client_provider_factory = Mock(return_value=http_client_provider)
model = OpenAIModel(
model_name="gpt-4",
api_key="test_key",
client_args={"http_client": stale_http_client, "timeout": 30},
http_client_factory=http_client_factory,
http_client_provider_factory=http_client_provider_factory,
)

with patch("trpc_agent_sdk.models._openai_model.openai.AsyncOpenAI") as mock_async_openai:
Expand Down Expand Up @@ -354,15 +446,19 @@ async def test_generate_async_simple_text_response(self):

@pytest.mark.asyncio
async def test_generate_async_validation_failure(self):
"""Test generate_async converts validation failures to error responses."""
"""Test generate_async returns an error response on invalid request."""
model = OpenAIModel(model_name="gpt-4", api_key="test_key")

# Empty contents
request = LlmRequest(contents=[], config=None, tools_dict={})

responses = [response async for response in model.generate_async(request, stream=False)]
responses = []
async for response in model.generate_async(request, stream=False):
responses.append(response)

assert len(responses) == 1
assert responses[0].error_code == "API_ERROR"
assert "At least one content is required" in (responses[0].error_message or "")
assert "At least one content is required" in responses[0].error_message

@pytest.mark.asyncio
async def test_generate_async_with_config_parameters(self):
Expand Down Expand Up @@ -849,47 +945,3 @@ def test_null_prompt_tokens_details_does_not_crash(self):
}
meta = OpenAIModel._build_usage_metadata(usage_data)
assert meta.cache_read_input_tokens is None


class _RetryTestError(Exception):

def __init__(self, status_code=None, headers=None):
super().__init__(f"status {status_code}" if status_code is not None else "retry test")
if status_code is not None:
self.status_code = status_code
if headers is not None:
self.response = type("Resp", (), {"headers": headers})()


class TestOpenAIModelRetryHooks:

def _model(self):
return OpenAIModel(model_name="gpt-4", api_key="test_key")

def test_x_should_retry_header_has_priority(self):
model = self._model()
assert model._get_model_retry_info(_RetryTestError(400, {"x-should-retry": "true"})).should_retry is True
assert model._get_model_retry_info(_RetryTestError(500, {"x-should-retry": "false"})).should_retry is False

@pytest.mark.parametrize("status_code", [408, 409, 429, 500, 503])
def test_retryable_status_codes(self, status_code):
assert self._model()._get_model_retry_info(_RetryTestError(status_code)).should_retry is True

@pytest.mark.parametrize("status_code", [400, 401, 403, 404, 499])
def test_non_retryable_status_codes(self, status_code):
assert self._model()._get_model_retry_info(_RetryTestError(status_code)).should_retry is False

def test_timeout_exception_retried(self):
request = httpx.Request("GET", "https://example.com")
assert self._model()._get_model_retry_info(httpx.TimeoutException("timeout", request=request)).should_retry is True

def test_openai_error_not_retried_without_status_decision(self):
assert self._model()._get_model_retry_info(openai.OpenAIError("boom")).should_retry is False

def test_other_exception_retried(self):
assert self._model()._get_model_retry_info(ValueError("boom")).should_retry is True

def test_retry_after_extracted_from_headers(self):
info = self._model()._get_model_retry_info(_RetryTestError(429, {"retry-after": "3"}))
assert info.should_retry is True
assert info.retry_after == 3.0
4 changes: 2 additions & 2 deletions trpc_agent_sdk/agents/_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ async def run_async(
- Actions
"""
from trpc_agent_sdk.telemetry import report_invoke_agent
from trpc_agent_sdk.telemetry._trace import tracer
from trpc_agent_sdk.telemetry._trace import trace_agent
from trpc_agent_sdk.telemetry import tracer
from trpc_agent_sdk.telemetry import trace_agent

# Manually propagate span context using attach/detach instead of
# start_as_current_span. This ensures child spans (call_llm, execute_tool,
Expand Down
8 changes: 8 additions & 0 deletions trpc_agent_sdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
from ._openai_model import OpenAIModel
from ._openai_model import ToolCall
from ._openai_model import ToolKey
from ._httpx_client import close_shared_http_clients
from ._httpx_client import temporary_http_client_provider_factory
from ._httpx_client import shared_http_client_provider_factory
from ._httpx_client import HttpClientProviderFactory
from ._registry import ModelRegistry
from ._registry import register_model

Expand Down Expand Up @@ -85,6 +89,10 @@
"OpenAIModel",
"ToolCall",
"ToolKey",
"close_shared_http_clients",
"temporary_http_client_provider_factory",
"shared_http_client_provider_factory",
"HttpClientProviderFactory",
"ModelRegistry",
"register_model",
]
Loading
Loading