diff --git a/agentrun/integration/langchain/model_adapter.py b/agentrun/integration/langchain/model_adapter.py index 8f9e494..6e1a956 100644 --- a/agentrun/integration/langchain/model_adapter.py +++ b/agentrun/integration/langchain/model_adapter.py @@ -2,15 +2,15 @@ 将 CommonModel 包装为 LangChain BaseChatModel。""" -import inspect -import json -from typing import Any, List, Optional +from typing import Any from agentrun.integration.langchain.message_adapter import ( LangChainMessageAdapter, ) from agentrun.integration.utils.adapter import ModelAdapter +_DEEPSEEK_PROVIDER = "deepseek" + class LangChainModelAdapter(ModelAdapter): """LangChain 模型适配器 / LangChain Model Adapter @@ -23,9 +23,45 @@ def __init__(self): def wrap_model(self, common_model: Any) -> Any: """包装 CommonModel 为 LangChain BaseChatModel / LangChain Model Adapter""" - from langchain_openai import ChatOpenAI - info = common_model.get_model_info() # 确保模型可用 + provider = (info.provider or "").lower() + + if provider == _DEEPSEEK_PROVIDER: + return self._create_reasoning_model(info) + return self._create_openai_model(info) + + def _create_reasoning_model(self, info: Any) -> Any: + """创建支持 reasoning_content 的模型(使用 ChatDeepSeek)""" + try: + from langchain_deepseek import ChatDeepSeek + except ImportError as e: + raise ImportError( + "import langchain_deepseek failed. " + "Install it with: pip install 'agentrun-sdk[langchain]' " + "or pip install 'agentrun-sdk[langgraph]'" + ) from e + + return ChatDeepSeek( + name=info.model, + model=info.model, + api_key=info.api_key, + api_base=info.base_url, + default_headers=info.headers, + stream_usage=True, + streaming=True, + ) + + def _create_openai_model(self, info: Any) -> Any: + """创建标准 OpenAI 兼容模型""" + try: + from langchain_openai import ChatOpenAI + except ImportError as e: + raise ImportError( + "import langchain_openai failed. " + "Install it with: pip install 'agentrun-sdk[langchain]' " + "or pip install 'agentrun-sdk[langgraph]'" + ) from e + return ChatOpenAI( name=info.model, api_key=info.api_key, @@ -33,5 +69,5 @@ def wrap_model(self, common_model: Any) -> Any: base_url=info.base_url, default_headers=info.headers, stream_usage=True, - streaming=True, # 启用流式输出以支持 token by token + streaming=True, ) diff --git a/agentrun/integration/langgraph/agent_converter.py b/agentrun/integration/langgraph/agent_converter.py index 61f2558..ac91e73 100644 --- a/agentrun/integration/langgraph/agent_converter.py +++ b/agentrun/integration/langgraph/agent_converter.py @@ -730,9 +730,19 @@ def _convert_astream_events_event( and not self.has_on_chat_model_stream ): chunk_data = data.get("chunk", {}) + messages = [] if isinstance(chunk_data, dict): messages = chunk_data.get("messages", []) + elif isinstance(chunk_data, list): + for item in chunk_data: + update = getattr(item, "update", None) + if not isinstance(update, dict): + continue + item_messages = update.get("messages", []) + if isinstance(item_messages, list): + messages.extend(item_messages) + if isinstance(messages, list): for msg in messages: content = AgentRunConverter._get_message_content(msg) if content: diff --git a/agentrun/model/__model_service_async_template.py b/agentrun/model/__model_service_async_template.py index 72d838f..5176801 100644 --- a/agentrun/model/__model_service_async_template.py +++ b/agentrun/model/__model_service_async_template.py @@ -233,4 +233,5 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo: base_url=self.provider_settings.base_url, model=default_model, headers=cfg.get_headers(), + provider=self.provider, ) diff --git a/agentrun/model/model_service.py b/agentrun/model/model_service.py index f04a8b6..2e6452a 100644 --- a/agentrun/model/model_service.py +++ b/agentrun/model/model_service.py @@ -404,4 +404,5 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo: base_url=self.provider_settings.base_url, model=default_model, headers=cfg.get_headers(), + provider=self.provider, ) diff --git a/agentrun/server/agui_protocol.py b/agentrun/server/agui_protocol.py index 89b640a..047cce4 100644 --- a/agentrun/server/agui_protocol.py +++ b/agentrun/server/agui_protocol.py @@ -8,6 +8,7 @@ """ from dataclasses import dataclass, field +import json from typing import ( Any, AsyncIterator, @@ -30,6 +31,10 @@ import pydash from ..utils.helper import merge, MergeOptions +from ..utils.reasoning import ( + get_reasoning_content, + is_thinking_enabled_from_env, +) from .model import ( AgentEvent, AgentRequest, @@ -60,6 +65,14 @@ class TextState: message_id: str = field(default_factory=lambda: str(uuid.uuid4())) +@dataclass +class ReasoningState: + started: bool = False + message_started: bool = False + phase_id: str = field(default_factory=lambda: str(uuid.uuid4())) + message_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + @dataclass class ToolCallState: name: str = "" @@ -72,6 +85,7 @@ class ToolCallState: @dataclass class StreamStateMachine: text: TextState = field(default_factory=TextState) + reasoning: ReasoningState = field(default_factory=ReasoningState) tool_call_states: Dict[str, ToolCallState] = field(default_factory=dict) tool_result_chunks: Dict[str, List[str]] = field(default_factory=dict) run_errored: bool = False @@ -121,6 +135,43 @@ def cache_tool_result_chunk(self, tool_id: str, delta: str) -> None: def pop_tool_result_chunks(self, tool_id: str) -> str: return "".join(self.tool_result_chunks.pop(tool_id, [])) + def ensure_reasoning_started(self) -> Iterator[str]: + if not self.reasoning.started: + yield _encode_reasoning_event( + "REASONING_START", + messageId=self.reasoning.phase_id, + ) + self.reasoning.started = True + if not self.reasoning.message_started: + yield _encode_reasoning_event( + "REASONING_MESSAGE_START", + messageId=self.reasoning.message_id, + role="reasoning", + ) + self.reasoning.message_started = True + + def end_reasoning_if_open(self) -> Iterator[str]: + if self.reasoning.message_started: + yield _encode_reasoning_event( + "REASONING_MESSAGE_END", + messageId=self.reasoning.message_id, + ) + self.reasoning.message_started = False + if self.reasoning.started: + yield _encode_reasoning_event( + "REASONING_END", + messageId=self.reasoning.phase_id, + ) + self.reasoning = ReasoningState() + + +def _encode_reasoning_event(event_type: str, **payload: Any) -> str: + return ( + "data: " + + json.dumps({"type": event_type, **payload}, ensure_ascii=False) + + "\n\n" + ) + class AGUIProtocolHandler(BaseProtocolHandler): """AG-UI 协议处理器 @@ -376,6 +427,10 @@ async def _format_stream( if state.run_errored: return + # 结束未结束的 reasoning 消息 + for sse_data in state.end_reasoning_if_open(): + yield sse_data + # 结束所有未结束的工具调用 for sse_data in state.end_all_tools(self._encoder): yield sse_data @@ -399,8 +454,6 @@ def _process_event_with_boundaries( state: StreamStateMachine, ) -> Iterator[str]: """处理事件并注入边界事件""" - import json - from ag_ui.core import CustomEvent as AguiCustomEvent from ag_ui.core import ( RunErrorEvent, @@ -413,6 +466,8 @@ def _process_event_with_boundaries( ToolCallStartEvent, ) + thinking_enabled = is_thinking_enabled_from_env() + # RAW 事件直接透传 if event.event == EventType.RAW: raw_data = event.data.get("raw", "") @@ -422,9 +477,46 @@ def _process_event_with_boundaries( yield raw_data return + if event.event == EventType.REASONING: + if thinking_enabled: + reasoning_content = ( + event.data.get("delta") + or get_reasoning_content(event.data) + or "" + ) + if reasoning_content: + for sse_data in state.end_text_if_open(self._encoder): + yield sse_data + for sse_data in state.end_all_tools(self._encoder): + yield sse_data + for sse_data in state.ensure_reasoning_started(): + yield sse_data + yield _encode_reasoning_event( + "REASONING_MESSAGE_CONTENT", + messageId=state.reasoning.message_id, + delta=reasoning_content, + ) + return + # TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START # AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL if event.event == EventType.TEXT: + addition = self._strip_reasoning_from_addition( + event.addition, thinking_enabled + ) + addition_reasoning = get_reasoning_content(event.addition or {}) + if thinking_enabled and addition_reasoning: + for sse_data in state.ensure_reasoning_started(): + yield sse_data + yield _encode_reasoning_event( + "REASONING_MESSAGE_CONTENT", + messageId=state.reasoning.message_id, + delta=addition_reasoning, + ) + + for sse_data in state.end_reasoning_if_open(): + yield sse_data + for sse_data in state.end_all_tools(self._encoder): yield sse_data @@ -435,13 +527,13 @@ def _process_event_with_boundaries( message_id=state.text.message_id, delta=event.data.get("delta", ""), ) - if event.addition: + if addition: event_dict = agui_event.model_dump( by_alias=True, exclude_none=True ) event_dict = self._apply_addition( event_dict, - event.addition, + addition, event.addition_merge_options, ) json_str = json.dumps(event_dict, ensure_ascii=False) @@ -455,6 +547,9 @@ def _process_event_with_boundaries( tool_id = event.data.get("id", "") tool_name = event.data.get("name", "") + for sse_data in state.end_reasoning_if_open(): + yield sse_data + for sse_data in state.end_text_if_open(self._encoder): yield sse_data @@ -491,6 +586,9 @@ def _process_event_with_boundaries( tool_name = event.data.get("name", "") tool_args = event.data.get("args", "") + for sse_data in state.end_reasoning_if_open(): + yield sse_data + for sse_data in state.end_text_if_open(self._encoder): yield sse_data @@ -541,6 +639,9 @@ def _process_event_with_boundaries( timeout = event.data.get("timeout") schema = event.data.get("schema") + for sse_data in state.end_reasoning_if_open(): + yield sse_data + for sse_data in state.end_text_if_open(self._encoder): yield sse_data @@ -601,6 +702,9 @@ def _process_event_with_boundaries( tool_id = event.data.get("id", "") tool_name = event.data.get("name", "") + for sse_data in state.end_reasoning_if_open(): + yield sse_data + for sse_data in state.end_text_if_open(self._encoder): yield sse_data @@ -767,6 +871,29 @@ def _apply_addition( return merge(event_data, addition, **(merge_options or {})) + def _strip_reasoning_from_addition( + self, + addition: Optional[Dict[str, Any]], + thinking_enabled: bool, + ) -> Optional[Dict[str, Any]]: + if not addition: + return addition + + stripped = dict(addition) + stripped.pop("reasoning_content", None) + additional_kwargs = stripped.get("additional_kwargs") + if isinstance(additional_kwargs, dict): + additional_kwargs = dict(additional_kwargs) + additional_kwargs.pop("reasoning_content", None) + if additional_kwargs: + stripped["additional_kwargs"] = additional_kwargs + else: + stripped.pop("additional_kwargs", None) + + if not thinking_enabled: + return stripped + return stripped or None + async def _error_stream(self, message: str) -> AsyncIterator[str]: """生成错误事件流 diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index 438c509..d08b65b 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -29,6 +29,7 @@ InvokeAgentHandler, SyncInvokeAgentHandler, ) +from agentrun.utils.reasoning import get_reasoning_content class AgentInvoker: @@ -124,6 +125,9 @@ async def invoke_stream( # 处理用户返回的事件 for processed_event in self._process_user_event(item): yield processed_event + else: + for processed_event in self._wrap_model_chunk(item): + yield processed_event else: # 非流式结果 results = self._wrap_non_stream(raw_result) @@ -238,6 +242,11 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]: data={"delta": item}, ) ) + else: + results.extend(self._wrap_model_chunk(item)) + + else: + results.extend(self._wrap_model_chunk(result)) return results @@ -267,6 +276,9 @@ async def _wrap_stream( elif isinstance(item, AgentEvent): for processed_event in self._process_user_event(item): yield processed_event + else: + for processed_event in self._wrap_model_chunk(item): + yield processed_event async def _iterate_async( self, content: Union[Iterator[Any], AsyncIterator[Any]] @@ -307,3 +319,31 @@ def _is_iterator(self, obj: Any) -> bool: if isinstance(obj, (str, bytes, dict, list, AgentEvent)): return False return hasattr(obj, "__iter__") or hasattr(obj, "__aiter__") + + def _wrap_model_chunk(self, item: Any) -> List[AgentEvent]: + """Convert common model chunks into AgentEvent objects.""" + events: List[AgentEvent] = [] + reasoning_content = get_reasoning_content(item) + if reasoning_content: + events.append( + AgentEvent( + event=EventType.REASONING, + data={"delta": reasoning_content}, + ) + ) + + content = self._read_attr_or_key(item, "content") + if isinstance(content, str) and content: + events.append( + AgentEvent( + event=EventType.TEXT, + data={"delta": content}, + ) + ) + + return events + + def _read_attr_or_key(self, obj: Any, key: str) -> Any: + if isinstance(obj, dict): + return obj.get(key) + return getattr(obj, key, None) diff --git a/agentrun/server/model.py b/agentrun/server/model.py index 1a7aed7..6033c48 100644 --- a/agentrun/server/model.py +++ b/agentrun/server/model.py @@ -14,7 +14,6 @@ Iterator, List, Optional, - TYPE_CHECKING, Union, ) @@ -91,6 +90,7 @@ class Message(BaseModel): id: Optional[str] = None role: MessageRole content: Optional[Union[str, List[Dict[str, Any]]]] = None + reasoning_content: Optional[str] = None name: Optional[str] = None tool_calls: Optional[List[ToolCall]] = None tool_call_id: Optional[str] = None @@ -125,6 +125,7 @@ class EventType(str, Enum): # 核心事件(用户主要使用) # ========================================================================= TEXT = "TEXT" # 文本内容块 + REASONING = "REASONING" # 模型思考内容块 TOOL_CALL = "TOOL_CALL" # 完整工具调用(含 id, name, args) TOOL_CALL_CHUNK = "TOOL_CALL_CHUNK" # 工具调用参数片段(流式场景) TOOL_RESULT = "TOOL_RESULT" # 工具执行结果(最终结果,标识流式输出结束) diff --git a/agentrun/server/openai_protocol.py b/agentrun/server/openai_protocol.py index 5c82ccf..99977c7 100644 --- a/agentrun/server/openai_protocol.py +++ b/agentrun/server/openai_protocol.py @@ -15,6 +15,10 @@ from fastapi.responses import JSONResponse, StreamingResponse import pydash +from ..utils.reasoning import ( + get_reasoning_content, + is_thinking_enabled_from_env, +) from ..utils.helper import merge, MergeOptions from .model import ( AgentEvent, @@ -22,7 +26,6 @@ EventType, Message, MessageRole, - OpenAIProtocolConfig, ServerConfig, Tool, ToolCall, @@ -242,6 +245,7 @@ def _parse_messages( Message( role=role, content=msg_data.get("content"), + reasoning_content=msg_data.get("reasoning_content"), name=msg_data.get("name"), tool_calls=tool_calls, tool_call_id=msg_data.get("tool_call_id"), @@ -300,6 +304,7 @@ async def _format_stream( # 状态追踪 sent_role = False has_text = False + thinking_enabled = is_thinking_enabled_from_env() tool_call_index = -1 # 从 -1 开始,第一个工具调用时变为 0 # 工具调用状态:{tool_id: {"started": bool, "index": int}} tool_call_states: Dict[str, Dict[str, Any]] = {} @@ -336,9 +341,21 @@ async def _format_stream( event.addition_merge_options, ) + self._apply_reasoning_gate(delta, thinking_enabled) yield self._build_chunk(context, delta) continue + if event.event == EventType.REASONING: + if thinking_enabled: + reasoning_content = event.data.get("delta", "") + if reasoning_content: + has_text = True + yield self._build_chunk( + context, + {"reasoning_content": reasoning_content}, + ) + continue + # TOOL_CALL_CHUNK 事件 if event.event == EventType.TOOL_CALL_CHUNK: tool_id = event.data.get("id", "") @@ -384,6 +401,7 @@ async def _format_stream( event.addition_merge_options, ) + self._apply_reasoning_gate(delta, thinking_enabled) yield self._build_chunk(context, delta) continue @@ -458,6 +476,8 @@ def _format_non_stream( OpenAI 格式的响应字典 """ content_parts: List[str] = [] + reasoning_parts: List[str] = [] + thinking_enabled = is_thinking_enabled_from_env() # 工具调用状态:{tool_id: {id, name, arguments}} tool_call_map: Dict[str, Dict[str, Any]] = {} has_tool_calls = False @@ -465,6 +485,14 @@ def _format_non_stream( for event in events: if event.event == EventType.TEXT: content_parts.append(event.data.get("delta", "")) + reasoning_content = get_reasoning_content(event.addition or {}) + if thinking_enabled and reasoning_content: + reasoning_parts.append(reasoning_content) + + elif event.event == EventType.REASONING: + reasoning_content = event.data.get("delta", "") + if thinking_enabled and reasoning_content: + reasoning_parts.append(reasoning_content) elif event.event == EventType.TOOL_CALL_CHUNK: tool_id = event.data.get("id", "") @@ -493,6 +521,8 @@ def _format_non_stream( "role": "assistant", "content": content, } + if reasoning_parts: + message["reasoning_content"] = "".join(reasoning_parts) if tool_call_map: message["tool_calls"] = list(tool_call_map.values()) @@ -533,3 +563,19 @@ def _apply_addition( return delta return merge(delta, addition, **(merge_options or {})) + + def _apply_reasoning_gate( + self, + payload: Dict[str, Any], + thinking_enabled: bool, + ) -> None: + if thinking_enabled: + reasoning_content = get_reasoning_content(payload) + if reasoning_content is not None: + payload["reasoning_content"] = reasoning_content + return + + payload.pop("reasoning_content", None) + additional_kwargs = payload.get("additional_kwargs") + if isinstance(additional_kwargs, dict): + additional_kwargs.pop("reasoning_content", None) diff --git a/agentrun/utils/reasoning.py b/agentrun/utils/reasoning.py new file mode 100644 index 0000000..54b7828 --- /dev/null +++ b/agentrun/utils/reasoning.py @@ -0,0 +1,92 @@ +"""Utilities for reasoning content extraction and gating.""" + +import json +import os +from collections.abc import Mapping +from typing import Any, Optional + + +def parse_bool(value: Any) -> Optional[bool]: + """Parse loose boolean values used by env-provided model parameters.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "1", "yes", "y", "on"}: + return True + if normalized in {"false", "0", "no", "n", "off"}: + return False + return None + + +def is_thinking_enabled_from_env( + environ: Mapping[str, str] = os.environ, +) -> bool: + """Return whether MODEL_PARAMETER_RULES enables thinking.""" + return get_thinking_value_from_env(environ) is True + + +def get_thinking_value_from_env( + environ: Mapping[str, str] = os.environ, +) -> Optional[bool]: + """Return the optional thinking value from MODEL_PARAMETER_RULES.""" + raw_rules = environ.get("MODEL_PARAMETER_RULES") + if not raw_rules: + return None + + try: + rules = json.loads(raw_rules) + except (TypeError, ValueError): + return None + + return _extract_thinking_value(rules) + + +def get_reasoning_content(chunk_or_message: Any) -> Optional[str]: + """Extract reasoning_content from common model chunk/message shapes.""" + value = _read_attr_or_key(chunk_or_message, "reasoning_content") + if value is not None: + return value + + additional_kwargs = _read_attr_or_key( + chunk_or_message, "additional_kwargs" + ) + if isinstance(additional_kwargs, Mapping): + value = additional_kwargs.get("reasoning_content") + if value is not None: + return value + + return None + + +def _extract_thinking_value(value: Any) -> Optional[bool]: + if isinstance(value, Mapping): + direct = parse_bool(value.get("thinking")) + if direct is not None: + return direct + + for nested_key in ("model_parameter_rules", "parameters", "rules"): + nested = value.get(nested_key) + nested_value = _extract_thinking_value(nested) + if nested_value is not None: + return nested_value + + if value.get("name") == "thinking": + for candidate_key in ("value", "default", "enabled"): + parsed = parse_bool(value.get(candidate_key)) + if parsed is not None: + return parsed + + if isinstance(value, list): + for item in value: + parsed = _extract_thinking_value(item) + if parsed is not None: + return parsed + + return None + + +def _read_attr_or_key(obj: Any, key: str) -> Any: + if isinstance(obj, Mapping): + return obj.get(key) + return getattr(obj, key, None) diff --git a/pyproject.toml b/pyproject.toml index 96b6ccb..9507141 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ server = [ langchain = [ "langchain>=1.0.0; python_version >= '3.10'", "langchain-openai>=1.0.0; python_version >= '3.10'", + "langchain-deepseek>=1.0.1; python_version >= '3.10'", ] google-adk = [ @@ -72,6 +73,8 @@ tablestore = [ langgraph = [ "langgraph>=0.2.0; python_version >= '3.10'", "langchain-core>=0.3.0; python_version >= '3.10'", + "langchain-openai>=1.0.0; python_version >= '3.10'", + "langchain-deepseek>=1.0.1; python_version >= '3.10'", ] [dependency-groups] diff --git a/scripts/smoke_reasoning_protocol.py b/scripts/smoke_reasoning_protocol.py new file mode 100644 index 0000000..df084bb --- /dev/null +++ b/scripts/smoke_reasoning_protocol.py @@ -0,0 +1,423 @@ +"""Smoke test reasoning content through AgentRun protocol handlers.""" + +import argparse +import asyncio +import json +import os +from typing import Any, Dict, Iterable, List, Optional + +from dotenv import load_dotenv +import httpx + +from agentrun.model import BackendType, ModelClient +from agentrun.model.api.data import ModelDataAPI +from agentrun.server import AgentEvent, AgentRequest, AgentRunServer, EventType +from agentrun.server.agui_protocol import AGUIProtocolHandler +from agentrun.server.openai_protocol import OpenAIProtocolHandler +from agentrun.utils.reasoning import ( + get_reasoning_content, + get_thinking_value_from_env, + is_thinking_enabled_from_env, + parse_bool, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--protocol", + choices=["openai", "agui", "both"], + default="both", + ) + parser.add_argument( + "--response-mode", + choices=["stream", "non-stream"], + default="stream", + ) + parser.add_argument("--env-file") + parser.add_argument("--model-resource") + parser.add_argument("--model") + parser.add_argument( + "--prompt", + default="用一句话回答:AgentRun 是什么?", + ) + parser.add_argument("--expect-reasoning", action="store_true") + parser.add_argument("--expect-no-reasoning", action="store_true") + parser.add_argument("--expect-content", action="store_true") + return parser.parse_args() + + +def load_env_file(path: Optional[str]) -> None: + if path: + load_dotenv(path, override=False) + if not os.getenv("AGENTRUN_REGION") and os.getenv("AGENTRUN_REGION_ID"): + os.environ["AGENTRUN_REGION"] = os.environ["AGENTRUN_REGION_ID"] + + +def model_parameter_rules() -> Dict[str, Any]: + raw = os.getenv("MODEL_PARAMETER_RULES") + if not raw: + return {} + try: + rules = json.loads(raw) + except ValueError: + return {} + return rules if isinstance(rules, dict) else {} + + +def model_call_kwargs() -> Dict[str, Any]: + kwargs = model_parameter_rules() + thinking = get_thinking_value_from_env() + direct_thinking = parse_bool(kwargs.pop("thinking", None)) + if direct_thinking is not None: + thinking = direct_thinking + + extra_body = kwargs.pop("extra_body", {}) + if not isinstance(extra_body, dict): + extra_body = {} + if thinking is not None: + extra_body["enable_thinking"] = thinking + if extra_body: + kwargs["extra_body"] = extra_body + return kwargs + + +def insecure_ssl_enabled() -> bool: + return os.getenv("AGENTRUN_SMOKE_INSECURE_SSL", "").lower() in { + "1", + "true", + "yes", + "on", + } + + +def read_value(obj: Any, key: str) -> Any: + if isinstance(obj, dict): + return obj.get(key) + return getattr(obj, key, None) + + +def iter_choices(response: Any) -> Iterable[Any]: + return read_value(response, "choices") or [] + + +def choice_message(choice: Any) -> Any: + return read_value(choice, "message") + + +def choice_delta(choice: Any) -> Any: + return read_value(choice, "delta") + + +def extract_content(obj: Any) -> Optional[str]: + value = read_value(obj, "content") + return value if isinstance(value, str) else None + + +def collect_non_stream(response: Any) -> tuple[str, str]: + content_parts: List[str] = [] + reasoning_parts: List[str] = [] + for choice in iter_choices(response): + message = choice_message(choice) + content = extract_content(message) + reasoning = get_reasoning_content(message) + if content: + content_parts.append(content) + if reasoning: + reasoning_parts.append(reasoning) + return "".join(content_parts), "".join(reasoning_parts) + + +async def build_agent_result(request: AgentRequest) -> Any: + model_resource = ( + os.getenv("AGENTRUN_MODEL_SERVICE_NAME") + or os.getenv("AGENTRUN_MODEL_PROXY_NAME") + or os.getenv("AGENTRUN_MODEL_NAME") + ) + model_name = os.getenv("AGENTRUN_MODEL_NAME") + if not model_resource: + raise RuntimeError( + "AGENTRUN_MODEL_SERVICE_NAME or AGENTRUN_MODEL_NAME is required" + ) + + messages = [ + { + "role": getattr(message.role, "value", message.role), + "content": message.content or "", + } + for message in request.messages + ] + kwargs = model_call_kwargs() + + if request.stream: + chunks = await call_real_model( + model_resource=model_resource, + model_name=model_name, + messages=messages, + stream=True, + kwargs=kwargs, + ) + + async def stream(): + for chunk in chunks: + for event in model_events_from_chunk(chunk): + yield event + + return stream() + + response = await call_real_model( + model_resource=model_resource, + model_name=model_name, + messages=messages, + stream=False, + kwargs=kwargs, + ) + content, reasoning = collect_non_stream(response) + events = [] + if reasoning: + events.append( + AgentEvent(event=EventType.REASONING, data={"delta": reasoning}) + ) + if content: + events.append(AgentEvent(event=EventType.TEXT, data={"delta": content})) + return events + + +def model_events_from_chunk(chunk: Any) -> List[AgentEvent]: + events: List[AgentEvent] = [] + for choice in iter_choices(chunk): + delta = choice_delta(choice) + content = extract_content(delta) + reasoning = ( + get_reasoning_content(delta) + or get_reasoning_content(choice) + or get_reasoning_content(chunk) + ) + if reasoning: + events.append( + AgentEvent( + event=EventType.REASONING, + data={"delta": reasoning}, + ) + ) + if content: + events.append( + AgentEvent(event=EventType.TEXT, data={"delta": content}) + ) + return events + + +async def call_real_model( + *, + model_resource: str, + model_name: Optional[str], + messages: List[Dict[str, str]], + stream: bool, + kwargs: Dict[str, Any], +) -> Any: + base_url, headers, default_model = resolve_model_endpoint(model_resource) + url = f"{base_url.rstrip('/')}/chat/completions" + payload = { + "model": model_name or default_model or model_resource, + "messages": messages, + "stream": stream, + **kwargs, + } + + async with httpx.AsyncClient( + timeout=180, verify=not insecure_ssl_enabled() + ) as client: + response = await client.post(url, headers=headers, json=payload) + + if response.is_error: + raise RuntimeError( + f"real model request failed: {response.status_code} {response.text}" + ) + + if not stream: + return response.json() + return parse_sse(response.text) + + +def resolve_model_endpoint( + model_resource: str, +) -> tuple[str, Dict[str, str], Optional[str]]: + if os.getenv("AGENTRUN_MODEL_SERVICE_NAME"): + service = ModelClient().get( + name=model_resource, backend_type=BackendType.SERVICE + ) + settings = service.provider_settings + if not settings or not settings.base_url or not settings.api_key: + raise RuntimeError( + f"model service {model_resource} has no provider settings" + ) + default_model = ( + settings.model_names[0] if settings.model_names else None + ) + return ( + settings.base_url, + { + "authorization": f"Bearer {settings.api_key}", + "content-type": "application/json", + }, + default_model, + ) + + data_api = ModelDataAPI(model_resource) + info = data_api.model_info() + return ( + info.base_url or "", + { + **(info.headers or {}), + "content-type": "application/json", + }, + info.model, + ) + + +async def call_openai(client: httpx.AsyncClient, args: argparse.Namespace): + response = await client.post( + "/openai/v1/chat/completions", + json={ + "model": args.model or os.getenv("AGENTRUN_MODEL_NAME"), + "stream": args.response_mode == "stream", + "messages": [{"role": "user", "content": args.prompt}], + }, + ) + if response.is_error: + raise AssertionError( + f"openai request failed: {response.status_code} {response.text}" + ) + + if args.response_mode == "stream": + events = parse_sse(response.text) + content = "".join( + event.get("choices", [{}])[0].get("delta", {}).get("content", "") + for event in events + if isinstance(event, dict) + ) + reasoning = "".join( + event.get("choices", [{}])[0] + .get("delta", {}) + .get("reasoning_content", "") + for event in events + if isinstance(event, dict) + ) + return content, reasoning, events + + payload = response.json() + message = payload.get("choices", [{}])[0].get("message", {}) + return ( + message.get("content") or "", + message.get("reasoning_content") or "", + payload, + ) + + +async def call_agui(client: httpx.AsyncClient, args: argparse.Namespace): + response = await client.post( + "/ag-ui/agent", + json={ + "threadId": "thread-local", + "runId": "run-local", + "messages": [{ + "id": "user-local", + "role": "user", + "content": args.prompt, + }], + "tools": [], + "context": [], + "forwardedProps": {}, + }, + ) + if response.is_error: + raise AssertionError( + f"agui request failed: {response.status_code} {response.text}" + ) + events = parse_sse(response.text) + content = "".join( + event.get("delta", "") + for event in events + if event.get("type") == "TEXT_MESSAGE_CONTENT" + ) + reasoning = "".join( + event.get("delta", "") + for event in events + if event.get("type") == "REASONING_MESSAGE_CONTENT" + ) + return content, reasoning, events + + +def parse_sse(text: str) -> List[Dict[str, Any]]: + events = [] + for line in text.splitlines(): + if not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + continue + events.append(json.loads(payload)) + return events + + +def validate_result(name: str, content: str, reasoning: str, args) -> None: + if args.expect_content and not content: + raise AssertionError(f"{name}: expected content but got empty content") + if args.expect_reasoning and not reasoning: + raise AssertionError(f"{name}: expected reasoning but got none") + if args.expect_no_reasoning and reasoning: + raise AssertionError(f"{name}: expected no reasoning but got one") + + +async def main() -> None: + args = parse_args() + load_env_file(args.env_file) + if args.model_resource: + os.environ["AGENTRUN_MODEL_SERVICE_NAME"] = args.model_resource + if args.model: + os.environ["AGENTRUN_MODEL_NAME"] = args.model + + app = AgentRunServer( + invoke_agent=build_agent_result, + protocols=[OpenAIProtocolHandler(), AGUIProtocolHandler()], + ).as_fastapi_app() + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient( + transport=transport, base_url="http://agentrun.local", timeout=180 + ) as client: + results = {} + if args.protocol in {"openai", "both"}: + content, reasoning, raw = await call_openai(client, args) + validate_result("openai", content, reasoning, args) + results["openai"] = summarize(content, reasoning, raw) + if args.protocol in {"agui", "both"}: + content, reasoning, raw = await call_agui(client, args) + validate_result("agui", content, reasoning, args) + results["agui"] = summarize(content, reasoning, raw) + + print( + json.dumps( + { + "thinkingEnabled": is_thinking_enabled_from_env(), + "protocol": args.protocol, + "responseMode": args.response_mode, + "results": results, + }, + ensure_ascii=False, + indent=2, + ) + ) + + +def summarize(content: str, reasoning: str, raw: Any) -> Dict[str, Any]: + return { + "contentPresent": bool(content), + "reasoningPresent": bool(reasoning), + "contentSample": content[:120], + "reasoningSample": reasoning[:120], + "rawEventCount": len(raw) if isinstance(raw, list) else None, + } + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/e2e/integration/langchain/test_agent_invoke_methods.py b/tests/e2e/integration/langchain/test_agent_invoke_methods.py index db04d5e..8fafd9a 100644 --- a/tests/e2e/integration/langchain/test_agent_invoke_methods.py +++ b/tests/e2e/integration/langchain/test_agent_invoke_methods.py @@ -430,7 +430,39 @@ def _normalize_agui_event(event: Dict[str, Any]) -> Dict[str, Any]: }, {"type": "TEXT_MESSAGE_END", "hasMessageId": True}, {"type": "RUN_FINISHED", "hasThreadId": True, "hasRunId": True}, - ] + ], + [ + {"type": "RUN_STARTED", "hasThreadId": True, "hasRunId": True}, + { + "type": "TOOL_CALL_START", + "toolCallName": "get_time", + "hasToolCallId": True, + }, + { + "type": "TOOL_CALL_ARGS", + "delta": "{}", + "hasToolCallId": True, + }, + {"type": "TOOL_CALL_END", "hasToolCallId": True}, + { + "type": "TOOL_CALL_RESULT", + "role": "tool", + "hasToolCallId": True, + "hasMessageId": True, + }, + { + "type": "TEXT_MESSAGE_START", + "role": "assistant", + "hasMessageId": True, + }, + { + "type": "TEXT_MESSAGE_CONTENT", + "delta": "工具结果已收到: 2024-01-01 12:00:00", + "hasMessageId": True, + }, + {"type": "TEXT_MESSAGE_END", "hasMessageId": True}, + {"type": "RUN_FINISHED", "hasThreadId": True, "hasRunId": True}, + ], ], } @@ -562,6 +594,15 @@ def _normalize_openai_stream( }], "finish_reason": None, }, + { + "object": "chat.completion.chunk", + "tool_calls": [{ + "name": None, + "arguments": "{}", + "has_id": False, + }], + "finish_reason": None, + }, { "object": "chat.completion.chunk", "delta_role": "assistant", @@ -623,7 +664,7 @@ def _normalize_openai_nonstream(resp: Dict[str, Any]) -> Dict[str, Any]: "content": "工具结果已收到: 2024-01-01 12:00:00", "tool_calls": [{ "name": "get_time", - "arguments": "", + "arguments": "{}", "has_id": True, }], "finish_reason": "tool_calls", @@ -810,9 +851,7 @@ async def test_astream_events( async def test_convert_python_3_10(self): from langchain.messages import ( AIMessage, - AIMessageChunk, HumanMessage, - SystemMessage, ) events = [ diff --git a/tests/e2e/test_langchain_server_demo.py b/tests/e2e/test_langchain_server_demo.py new file mode 100644 index 0000000..1a96073 --- /dev/null +++ b/tests/e2e/test_langchain_server_demo.py @@ -0,0 +1,394 @@ +"""LangChain AgentRunServer demo e2e case.""" + +from dataclasses import dataclass +import inspect +import json +import os +import socket +import threading +import time +from typing import Any, Dict, Generator, List + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.testclient import TestClient +import httpx +from langchain.agents import create_agent +import pydash +import pytest +import uvicorn + +from agentrun.integration.langchain import ( + AgentRunConverter, + model, + sandbox_toolset, +) +from agentrun.integration.utils.model import CommonModel +import agentrun.memory_collection as memory_collection_module +from agentrun.model import ModelService, ModelType, ProviderSettings +from agentrun.sandbox import TemplateType +from agentrun.server import AgentRequest, AgentRunServer + +MODEL_NAME = "demo-model" +MODEL_SERVICE_NAME = "demo-model-service" +MEMORY_COLLECTION_NAME = "cafe-mem" +MODEL_REPLY = "你好,我是网页分析助手。" + + +@dataclass +class MockOpenAIServer: + base_url: str + requests: List[Dict[str, Any]] + + +@dataclass +class DemoRuntime: + client: TestClient + mock_openai: MockOpenAIServer + memory_collection_names: List[str] + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +def _sse(data: Dict[str, Any]) -> str: + return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def _chat_payload(model_name: str, content: str) -> Dict[str, Any]: + return { + "id": "chatcmpl-demo", + "object": "chat.completion", + "created": int(time.time()), + "model": model_name, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + }], + } + + +async def _stream_chat(model_name: str, content: str): + yield _sse({ + "id": "chatcmpl-demo", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [{ + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None, + }], + }) + yield _sse({ + "id": "chatcmpl-demo", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [{ + "index": 0, + "delta": {"content": content}, + "finish_reason": None, + }], + }) + yield _sse({ + "id": "chatcmpl-demo", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + }) + yield "data: [DONE]\n\n" + + +def _build_mock_openai_app( + request_log: List[Dict[str, Any]], +) -> FastAPI: + app = FastAPI() + + @app.get("/v1/models") + async def list_models(): + return { + "object": "list", + "data": [ + {"id": MODEL_NAME, "object": "model", "owned_by": "local"} + ], + } + + @app.post("/v1/chat/completions") + async def chat_completions(request: Request): + body = await request.json() + request_log.append({ + "headers": dict(request.headers), + "body": body, + }) + + model_name = body.get("model") or MODEL_NAME + if body.get("stream"): + return StreamingResponse( + _stream_chat(model_name, MODEL_REPLY), + media_type="text/event-stream", + ) + + return JSONResponse(_chat_payload(model_name, MODEL_REPLY)) + + return app + + +def _parse_sse_events(content: str) -> List[Dict[str, Any]]: + events: List[Dict[str, Any]] = [] + for line in content.splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + data = line[5:].strip() + if not data or data == "[DONE]": + continue + events.append(json.loads(data)) + return events + + +def _openai_content(events: List[Dict[str, Any]]) -> str: + chunks = [] + for event in events: + choice = (event.get("choices") or [{}])[0] or {} + delta = choice.get("delta") or {} + chunks.append(delta.get("content", "")) + return "".join(chunks) + + +def _agui_content(events: List[Dict[str, Any]]) -> str: + return "".join( + event.get("delta", "") + for event in events + if event.get("type") == "TEXT_MESSAGE_CONTENT" + ) + + +def _sandbox_tools() -> List[Any]: + sandbox_name = os.getenv("SANDBOX_NAME") + if sandbox_name and not sandbox_name.startswith("请替换"): + return sandbox_toolset( + template_name=sandbox_name, + template_type=TemplateType.CODE_INTERPRETER, + sandbox_idle_timeout_seconds=300, + ) + return [] + + +async def _yield_agent_handler_result( + request: AgentRequest, + agent_handler: Any, +): + result = agent_handler(request) + if inspect.isawaitable(result): + result = await result + + if hasattr(result, "__aiter__"): + async for item in result: + yield item + return + + if result is not None: + yield result + + +def _build_demo_app( + model_service: ModelService, + monkeypatch: pytest.MonkeyPatch, +) -> FastAPI: + monkeypatch.setenv("MODEL_NAME", MODEL_NAME) + monkeypatch.setenv("MODEL_SERVICE_NAME", MODEL_SERVICE_NAME) + monkeypatch.setenv("OPENAI_API_KEY", "agentrun") + monkeypatch.setenv("SANDBOX_NAME", "") + + original_get_model_info = CommonModel.get_model_info + + def get_model_info(self, config=None): + info = original_get_model_info(self, config) + if not (info.api_key or "").strip(): + fallback = os.getenv("OPENAI_API_KEY", "agentrun").strip() + if fallback: + info.api_key = "agentrun" + return info + + monkeypatch.setattr(CommonModel, "get_model_info", get_model_info) + + memory_collection_names: List[str] = [] + + class MemoryConversationPassthrough: + + def __init__(self, memory_collection_name: str): + memory_collection_names.append(memory_collection_name) + + async def wrap_invoke_agent(self, request, agent_handler): + async for item in _yield_agent_handler_result( + request, agent_handler + ): + yield item + + monkeypatch.setattr( + memory_collection_module, + "MemoryConversation", + MemoryConversationPassthrough, + ) + + if not os.getenv("MODEL_SERVICE_NAME"): + raise ValueError("请将 MODEL_SERVICE_NAME 替换为您已经创建的模型名称") + + agent = create_agent( + model=model(model_service, model=os.getenv("MODEL_NAME")), + tools=[*_sandbox_tools()], + system_prompt="你是网页分析助手。", + ) + + def invoke_agent(request: AgentRequest): + input_data: Any = { + "messages": [ + { + "content": message.content, + "role": ( + message.role.value + if hasattr(message.role, "value") + else str(message.role) + ), + } + for message in request.messages + ] + } + converter = AgentRunConverter() + + if request.stream: + + async def stream_generator(): + result = agent.astream_events(input_data, version="v2") + async for chunk in result: + for item in converter.convert(chunk): + yield item + + return stream_generator() + + result = agent.invoke(input_data) + return pydash.get(result, "messages.-1.content") + + app = AgentRunServer( + invoke_agent=invoke_agent, + memory_collection_name=MEMORY_COLLECTION_NAME, + ).as_fastapi_app() + app.state.memory_collection_names = memory_collection_names + return app + + +@pytest.fixture(scope="module") +def mock_openai_server() -> Generator[MockOpenAIServer, None, None]: + request_log: List[Dict[str, Any]] = [] + app = _build_mock_openai_app(request_log) + port = _find_free_port() + config = uvicorn.Config( + app, host="127.0.0.1", port=port, log_level="warning" + ) + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + base_url = f"http://127.0.0.1:{port}" + for _ in range(50): + try: + httpx.get(f"{base_url}/v1/models", timeout=0.2) + break + except Exception: + time.sleep(0.1) + else: + server.should_exit = True + thread.join(timeout=5) + raise RuntimeError("mock OpenAI server did not start") + + yield MockOpenAIServer(base_url=base_url, requests=request_log) + + server.should_exit = True + thread.join(timeout=5) + + +@pytest.fixture +def demo_runtime( + mock_openai_server: MockOpenAIServer, + monkeypatch: pytest.MonkeyPatch, +) -> Generator[DemoRuntime, None, None]: + mock_openai_server.requests.clear() + model_service = ModelService( + model_service_name=MODEL_SERVICE_NAME, + model_type=ModelType.LLM, + provider="openai", + provider_settings=ProviderSettings( + api_key="", + base_url=f"{mock_openai_server.base_url}/v1", + model_names=[MODEL_NAME], + ), + ) + app = _build_demo_app(model_service, monkeypatch) + + with TestClient(app) as client: + yield DemoRuntime( + client=client, + mock_openai=mock_openai_server, + memory_collection_names=app.state.memory_collection_names, + ) + + +def test_langchain_demo_openai_streaming(demo_runtime: DemoRuntime): + response = demo_runtime.client.post( + "/openai/v1/chat/completions", + json={ + "model": MODEL_NAME, + "messages": [{"role": "user", "content": "你好?"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + assert "data: [DONE]" in response.text + assert MODEL_REPLY == _openai_content(_parse_sse_events(response.text)) + assert MEMORY_COLLECTION_NAME in demo_runtime.memory_collection_names + + +def test_langchain_demo_openai_plain_response(demo_runtime: DemoRuntime): + response = demo_runtime.client.post( + "/openai/v1/chat/completions", + json={ + "model": MODEL_NAME, + "messages": [{"role": "user", "content": "你好?"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert pydash.get(payload, "choices.0.message.content") == MODEL_REPLY + assert any( + item["headers"].get("authorization") == "Bearer agentrun" + for item in demo_runtime.mock_openai.requests + ) + + +def test_langchain_demo_agui_events(demo_runtime: DemoRuntime): + response = demo_runtime.client.post( + "/ag-ui/agent", + json={ + "messages": [ + {"role": "user", "content": "写一段代码,查询现在是几点?"} + ] + }, + ) + + assert response.status_code == 200 + events = _parse_sse_events(response.text) + event_types = [event.get("type") for event in events] + assert "RUN_STARTED" in event_types + assert "TEXT_MESSAGE_CONTENT" in event_types + assert "RUN_FINISHED" in event_types + assert MODEL_REPLY == _agui_content(events) diff --git a/tests/e2e/test_reasoning_protocol.py b/tests/e2e/test_reasoning_protocol.py new file mode 100644 index 0000000..86fda8e --- /dev/null +++ b/tests/e2e/test_reasoning_protocol.py @@ -0,0 +1,153 @@ +"""E2E coverage for reasoning_content protocol output.""" + +import json +from types import SimpleNamespace +from typing import Any, Dict, List + +import httpx +import pytest + +from agentrun.server import AgentRequest, AgentRunServer + + +def _parse_sse_events(content: str) -> List[Dict[str, Any]]: + events = [] + for line in content.splitlines(): + if not line.startswith("data: "): + continue + payload = line[6:] + if payload == "[DONE]": + continue + events.append(json.loads(payload)) + return events + + +@pytest.fixture +def reasoning_app(): + async def invoke_agent(request: AgentRequest): + yield SimpleNamespace( + content="", + additional_kwargs={"reasoning_content": "thinking"}, + ) + yield SimpleNamespace(content="answer", additional_kwargs={}) + + return AgentRunServer(invoke_agent=invoke_agent).as_fastapi_app() + + +async def _post_json(app, path: str, payload: Dict[str, Any]) -> httpx.Response: + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as client: + return await client.post(path, json=payload, timeout=60.0) + + +def _set_thinking(monkeypatch, enabled: bool) -> None: + monkeypatch.setenv( + "MODEL_PARAMETER_RULES", + json.dumps({"thinking": enabled}), + ) + + +@pytest.mark.parametrize("thinking_enabled", [True, False]) +@pytest.mark.asyncio +async def test_openai_stream_reasoning_content_gate( + reasoning_app, + monkeypatch, + thinking_enabled: bool, +): + _set_thinking(monkeypatch, thinking_enabled) + + response = await _post_json( + reasoning_app, + "/openai/v1/chat/completions", + { + "model": "mock-model", + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + events = _parse_sse_events(response.text) + deltas = [ + (event.get("choices") or [{}])[0].get("delta") or {} + for event in events + ] + reasoning = "".join(delta.get("reasoning_content", "") for delta in deltas) + content = "".join(delta.get("content", "") for delta in deltas) + + assert content == "answer" + assert reasoning == ("thinking" if thinking_enabled else "") + assert all("additional_kwargs" not in delta for delta in deltas) + + +@pytest.mark.parametrize("thinking_enabled", [True, False]) +@pytest.mark.asyncio +async def test_openai_non_stream_reasoning_content_gate( + reasoning_app, + monkeypatch, + thinking_enabled: bool, +): + _set_thinking(monkeypatch, thinking_enabled) + + response = await _post_json( + reasoning_app, + "/openai/v1/chat/completions", + { + "model": "mock-model", + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + message = response.json()["choices"][0]["message"] + assert message["content"] == "answer" + if thinking_enabled: + assert message["reasoning_content"] == "thinking" + else: + assert "reasoning_content" not in message + + +@pytest.mark.parametrize("thinking_enabled", [True, False]) +@pytest.mark.asyncio +async def test_agui_reasoning_events_gate( + reasoning_app, + monkeypatch, + thinking_enabled: bool, +): + _set_thinking(monkeypatch, thinking_enabled) + + response = await _post_json( + reasoning_app, + "/ag-ui/agent", + {"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + events = _parse_sse_events(response.text) + event_types = [event["type"] for event in events] + reasoning = "".join( + event.get("delta", "") + for event in events + if event["type"] == "REASONING_MESSAGE_CONTENT" + ) + content = "".join( + event.get("delta", "") + for event in events + if event["type"] == "TEXT_MESSAGE_CONTENT" + ) + + assert content == "answer" + if thinking_enabled: + assert reasoning == "thinking" + assert event_types.index("REASONING_MESSAGE_CONTENT") < event_types.index( + "TEXT_MESSAGE_START" + ) + else: + assert reasoning == "" + assert all( + not event_type.startswith("REASONING") + for event_type in event_types + ) diff --git a/tests/unittests/integration/test_langchain.py b/tests/unittests/integration/test_langchain.py index 712842a..4cecb0b 100644 --- a/tests/unittests/integration/test_langchain.py +++ b/tests/unittests/integration/test_langchain.py @@ -15,7 +15,9 @@ from agentrun.integration.builtin.model import CommonModel from agentrun.integration.utils.tool import CommonToolSet, tool +from agentrun.model import ModelService, ProviderSettings from agentrun.model.model_proxy import ModelProxy +from agentrun.utils.config import Config from .base import IntegrationTestBase, IntegrationTestResult, ToolCallInfo from .mock_llm_server import MockLLMServer @@ -173,6 +175,20 @@ def _msg_to_dict(self, msg: Any) -> dict: class TestLangChainIntegration(LangChainTestMixin): """LangChain Integration 测试类""" + def _model_service_model(self, provider: Optional[str]) -> CommonModel: + return CommonModel( + ModelService( + model_service_name=f"{provider or 'default'}-service", + provider=provider, + provider_settings=ProviderSettings( + api_key="sk-test", + base_url="https://model.example/v1", + model_names=["test-model"], + ), + ), + config=Config(headers={"x-test-header": "yes"}), + ) + @pytest.fixture def mock_server(self, monkeypatch: Any, respx_mock: Any) -> MockLLMServer: """创建并安装 Mock LLM Server @@ -320,6 +336,49 @@ def test_stream_options_in_requests( assert llm.stream_usage is True assert llm.streaming is True + def test_model_service_model_info_exposes_provider(self): + model = self._model_service_model("deepseek") + + assert model.get_model_info().provider == "deepseek" + + def test_deepseek_provider_uses_chat_deepseek(self): + from langchain_deepseek import ChatDeepSeek + + model = self._model_service_model("deepseek") + + llm = model.to_langchain() + + assert isinstance(llm, ChatDeepSeek) + assert llm.name == "test-model" + assert llm.model_name == "test-model" + assert llm.api_base == "https://model.example/v1" + assert llm.default_headers == {"x-test-header": "yes"} + assert llm.openai_api_key.get_secret_value() == "sk-test" + assert llm.stream_usage is True + assert llm.streaming is True + + @pytest.mark.parametrize( + "provider", + [None, "custom", "tongyi", "zhipuai", "moonshot", "minimax", "unknown"], + ) + def test_non_deepseek_providers_use_chat_openai( + self, provider: Optional[str] + ): + from langchain_openai import ChatOpenAI + + model = self._model_service_model(provider) + + llm = model.to_langchain() + + assert isinstance(llm, ChatOpenAI) + assert llm.name == "test-model" + assert llm.model_name == "test-model" + assert llm.openai_api_base == "https://model.example/v1" + assert llm.default_headers == {"x-test-header": "yes"} + assert llm.openai_api_key.get_secret_value() == "sk-test" + assert llm.stream_usage is True + assert llm.streaming is True + def test_stream_options_validation( self, mock_server: MockLLMServer, diff --git a/tests/unittests/integration/test_langchain_convert.py b/tests/unittests/integration/test_langchain_convert.py index 16feeff..fc35e08 100644 --- a/tests/unittests/integration/test_langchain_convert.py +++ b/tests/unittests/integration/test_langchain_convert.py @@ -105,6 +105,24 @@ def test_is_stream_values_format(self): class TestConvertAstreamEventsFormat: """测试 astream_events 格式的事件转换""" + def test_on_chain_stream_model_command_update_text(self): + """测试 Command(update={messages: ...}) 形式的模型输出""" + + class CommandLike: + + def __init__(self): + self.update = {"messages": [create_mock_ai_message("你好")]} + + event = { + "event": "on_chain_stream", + "name": "model", + "data": {"chunk": [CommandLike()]}, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert results == ["你好"] + def test_on_chat_model_stream_text_content(self): """测试 on_chat_model_stream 事件的文本内容提取""" chunk = create_mock_ai_message_chunk("你好") diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py index e7196cb..eefc20b 100644 --- a/tests/unittests/server/test_agui_protocol.py +++ b/tests/unittests/server/test_agui_protocol.py @@ -4,7 +4,7 @@ """ import json -from typing import cast +from types import SimpleNamespace from fastapi.testclient import TestClient import pytest @@ -36,6 +36,14 @@ def test_get_prefix_custom(self): assert handler.get_prefix() == "/custom/agui" +def _agui_sse_events(response): + return [ + json.loads(line[6:]) + for line in response.text.splitlines() + if line.startswith("data: {") + ] + + class TestAGUIProtocolEndpoints: """测试 AG-UI 协议端点""" @@ -1185,3 +1193,151 @@ async def invoke_agent(request: AgentRequest): assert "TOOL_CALL_START" in types assert "TOOL_CALL_END" in types assert "TOOL_CALL_RESULT" in types + + +class TestAGUIReasoningContent: + """测试 AG-UI reasoning 事件输出开关""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + def test_stream_includes_reasoning_when_thinking_enabled(self, monkeypatch): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}') + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.REASONING, + data={"delta": "thinking"}, + ) + yield AgentEvent(event=EventType.TEXT, data={"delta": "answer"}) + + response = self.get_client(invoke_agent).post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + events = _agui_sse_events(response) + types = [event["type"] for event in events] + reasoning_event = next( + event + for event in events + if event["type"] == "REASONING_MESSAGE_CONTENT" + ) + assert "REASONING_START" in types + assert reasoning_event["delta"] == "thinking" + assert "TEXT_MESSAGE_CONTENT" in types + + def test_stream_suppresses_reasoning_when_thinking_disabled( + self, monkeypatch + ): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}') + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.REASONING, + data={"delta": "thinking"}, + ) + yield AgentEvent(event=EventType.TEXT, data={"delta": "answer"}) + + response = self.get_client(invoke_agent).post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + events = _agui_sse_events(response) + assert "REASONING_MESSAGE_CONTENT" not in [ + event["type"] for event in events + ] + text_event = next( + event for event in events if event["type"] == "TEXT_MESSAGE_CONTENT" + ) + assert text_event["delta"] == "answer" + + def test_stream_promotes_chunk_additional_kwargs_reasoning( + self, monkeypatch + ): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}') + + async def invoke_agent(request: AgentRequest): + yield SimpleNamespace( + content="answer", + additional_kwargs={"reasoning_content": "thinking"}, + ) + + response = self.get_client(invoke_agent).post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + events = _agui_sse_events(response) + reasoning_event = next( + event + for event in events + if event["type"] == "REASONING_MESSAGE_CONTENT" + ) + text_event = next( + event for event in events if event["type"] == "TEXT_MESSAGE_CONTENT" + ) + assert reasoning_event["delta"] == "thinking" + assert text_event["delta"] == "answer" + + def test_text_addition_reasoning_is_emitted_before_text( + self, monkeypatch + ): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}') + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TEXT, + data={"delta": "answer"}, + addition={ + "additional_kwargs": {"reasoning_content": "thinking"} + }, + ) + + response = self.get_client(invoke_agent).post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + events = _agui_sse_events(response) + types = [event["type"] for event in events] + assert types.index("REASONING_MESSAGE_CONTENT") < types.index( + "TEXT_MESSAGE_START" + ) + assert "REASONING_MESSAGE_END" in types + assert "REASONING_END" in types + text_event = next( + event for event in events if event["type"] == "TEXT_MESSAGE_CONTENT" + ) + assert text_event["delta"] == "answer" + assert "additional_kwargs" not in text_event + + def test_text_addition_reasoning_is_stripped_when_thinking_disabled( + self, monkeypatch + ): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}') + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TEXT, + data={"delta": "answer"}, + addition={ + "additional_kwargs": {"reasoning_content": "thinking"} + }, + ) + + response = self.get_client(invoke_agent).post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + events = _agui_sse_events(response) + types = [event["type"] for event in events] + assert all(not event_type.startswith("REASONING") for event_type in types) + text_event = next( + event for event in events if event["type"] == "TEXT_MESSAGE_CONTENT" + ) + assert text_event["delta"] == "answer" + assert "additional_kwargs" not in text_event diff --git a/tests/unittests/server/test_openai_protocol.py b/tests/unittests/server/test_openai_protocol.py index 53e2c30..1b15052 100644 --- a/tests/unittests/server/test_openai_protocol.py +++ b/tests/unittests/server/test_openai_protocol.py @@ -4,7 +4,7 @@ """ import json -from typing import cast +from types import SimpleNamespace from fastapi.testclient import TestClient import pytest @@ -51,6 +51,14 @@ def test_get_model_name_custom(self): assert handler.get_model_name() == "custom-model" +def _openai_sse_events(response): + return [ + json.loads(line[6:]) + for line in response.text.splitlines() + if line.startswith("data: {") + ] + + class TestOpenAIProtocolEndpoints: """测试 OpenAI 协议端点""" @@ -1006,3 +1014,157 @@ def invoke_agent(request: AgentRequest): tool_calls = data["choices"][0]["message"]["tool_calls"] assert len(tool_calls) == 1 assert tool_calls[0]["function"]["arguments"] == "" + + +class TestOpenAIReasoningContent: + """测试 OpenAI reasoning_content 输出开关""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + def test_stream_includes_reasoning_when_thinking_enabled(self, monkeypatch): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}') + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.REASONING, + data={"delta": "thinking"}, + ) + yield AgentEvent(event=EventType.TEXT, data={"delta": "answer"}) + + response = self.get_client(invoke_agent).post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + events = _openai_sse_events(response) + assert events[0]["choices"][0]["delta"]["reasoning_content"] == "thinking" + assert events[1]["choices"][0]["delta"]["content"] == "answer" + + def test_stream_suppresses_reasoning_when_thinking_disabled( + self, monkeypatch + ): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}') + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.REASONING, + data={"delta": "thinking"}, + ) + yield AgentEvent(event=EventType.TEXT, data={"delta": "answer"}) + + response = self.get_client(invoke_agent).post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + events = _openai_sse_events(response) + assert all( + "reasoning_content" not in event["choices"][0]["delta"] + for event in events + ) + assert events[0]["choices"][0]["delta"]["content"] == "answer" + + def test_non_stream_includes_reasoning_when_thinking_enabled( + self, monkeypatch + ): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}') + + def invoke_agent(request: AgentRequest): + return [ + AgentEvent( + event=EventType.REASONING, + data={"delta": "thinking"}, + ), + AgentEvent(event=EventType.TEXT, data={"delta": "answer"}), + ] + + response = self.get_client(invoke_agent).post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + message = response.json()["choices"][0]["message"] + assert message["content"] == "answer" + assert message["reasoning_content"] == "thinking" + + def test_non_stream_suppresses_reasoning_when_thinking_disabled( + self, monkeypatch + ): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}') + + def invoke_agent(request: AgentRequest): + return [ + AgentEvent( + event=EventType.REASONING, + data={"delta": "thinking"}, + ), + AgentEvent(event=EventType.TEXT, data={"delta": "answer"}), + ] + + response = self.get_client(invoke_agent).post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + message = response.json()["choices"][0]["message"] + assert message["content"] == "answer" + assert "reasoning_content" not in message + + def test_stream_promotes_chunk_additional_kwargs_reasoning( + self, monkeypatch + ): + monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}') + + async def invoke_agent(request: AgentRequest): + yield SimpleNamespace( + content="answer", + additional_kwargs={"reasoning_content": "thinking"}, + ) + + response = self.get_client(invoke_agent).post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + events = _openai_sse_events(response) + assert events[0]["choices"][0]["delta"]["reasoning_content"] == "thinking" + assert events[1]["choices"][0]["delta"]["content"] == "answer" + + def test_parses_request_message_reasoning_content(self): + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["messages"] = request.messages + return "Done" + + response = self.get_client(invoke_agent).post( + "/openai/v1/chat/completions", + json={ + "messages": [{ + "role": "assistant", + "content": "answer", + "reasoning_content": "thinking", + }], + "stream": False, + }, + ) + + assert response.status_code == 200 + assert captured_request["messages"][0].reasoning_content == "thinking" diff --git a/tests/unittests/server/test_reasoning.py b/tests/unittests/server/test_reasoning.py new file mode 100644 index 0000000..51cadf4 --- /dev/null +++ b/tests/unittests/server/test_reasoning.py @@ -0,0 +1,84 @@ +from types import SimpleNamespace + +import pytest + +from agentrun.server import AgentRequest, EventType, MessageRole +from agentrun.server.invoker import AgentInvoker +from agentrun.server.model import Message +from agentrun.utils.reasoning import ( + get_reasoning_content, + get_thinking_value_from_env, + is_thinking_enabled_from_env, +) + + +def test_model_parameter_rules_object_enables_thinking(): + env = {"MODEL_PARAMETER_RULES": '{"thinking": true}'} + + assert is_thinking_enabled_from_env(env) is True + + +def test_model_parameter_rules_list_enables_thinking(): + env = { + "MODEL_PARAMETER_RULES": ( + '[{"name": "temperature", "default": 0.1}, ' + '{"name": "thinking", "default": "true"}]' + ) + } + + assert is_thinking_enabled_from_env(env) is True + assert get_thinking_value_from_env(env) is True + + +def test_model_parameter_rules_nested_parameters_disables_thinking(): + env = { + "MODEL_PARAMETER_RULES": ( + '{"parameters": [{"name": "thinking", "default": "false"}]}' + ) + } + + assert is_thinking_enabled_from_env(env) is False + assert get_thinking_value_from_env(env) is False + + +def test_model_parameter_rules_invalid_json_disables_thinking(): + env = {"MODEL_PARAMETER_RULES": "not json"} + + assert is_thinking_enabled_from_env(env) is False + + +def test_get_reasoning_content_from_attribute(): + chunk = SimpleNamespace(reasoning_content="thinking") + + assert get_reasoning_content(chunk) == "thinking" + + +def test_get_reasoning_content_from_additional_kwargs(): + chunk = {"additional_kwargs": {"reasoning_content": "thinking"}} + + assert get_reasoning_content(chunk) == "thinking" + + +@pytest.mark.asyncio +async def test_invoker_converts_chunk_additional_kwargs_to_reasoning_event(): + chunk = SimpleNamespace( + content="answer", + additional_kwargs={"reasoning_content": "thinking"}, + ) + + async def invoke_agent(request): + yield chunk + + request = AgentRequest( + protocol="openai", + messages=[Message(role=MessageRole.USER, content="hello")], + stream=True, + raw_request=None, + ) + + events = [event async for event in AgentInvoker(invoke_agent).invoke_stream(request)] + + assert events[0].event == EventType.REASONING + assert events[0].data["delta"] == "thinking" + assert events[1].event == EventType.TEXT + assert events[1].data["delta"] == "answer"