Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions agentrun/integration/langchain/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

将 CommonModel 包装为 LangChain BaseChatModel。"""

import inspect
import json
from typing import Any, List, Optional
from typing import Any

from agentrun.integration.langchain.message_adapter import (
LangChainMessageAdapter,
)
from agentrun.integration.utils.adapter import ModelAdapter

_DEEPSEEK_PROVIDER = "deepseek"


class LangChainModelAdapter(ModelAdapter):
"""LangChain 模型适配器 / LangChain Model Adapter
Expand All @@ -23,15 +23,51 @@ def __init__(self):

def wrap_model(self, common_model: Any) -> Any:
"""包装 CommonModel 为 LangChain BaseChatModel / LangChain Model Adapter"""
from langchain_openai import ChatOpenAI

info = common_model.get_model_info() # 确保模型可用
provider = (info.provider or "").lower()

if provider == _DEEPSEEK_PROVIDER:
return self._create_reasoning_model(info)
return self._create_openai_model(info)
Comment on lines +27 to +31

def _create_reasoning_model(self, info: Any) -> Any:
"""创建支持 reasoning_content 的模型(使用 ChatDeepSeek)"""
try:
from langchain_deepseek import ChatDeepSeek
except ImportError as e:
raise ImportError(
"import langchain_deepseek failed. "
"Install it with: pip install 'agentrun-sdk[langchain]' "
"or pip install 'agentrun-sdk[langgraph]'"
) from e

return ChatDeepSeek(
name=info.model,
model=info.model,
api_key=info.api_key,
api_base=info.base_url,
default_headers=info.headers,
stream_usage=True,
streaming=True,
)
Comment on lines +44 to +52
Comment on lines +33 to +52

def _create_openai_model(self, info: Any) -> Any:
"""创建标准 OpenAI 兼容模型"""
try:
from langchain_openai import ChatOpenAI
except ImportError as e:
raise ImportError(
"import langchain_openai failed. "
"Install it with: pip install 'agentrun-sdk[langchain]' "
"or pip install 'agentrun-sdk[langgraph]'"
) from e

return ChatOpenAI(
name=info.model,
api_key=info.api_key,
model=info.model,
base_url=info.base_url,
default_headers=info.headers,
stream_usage=True,
streaming=True, # 启用流式输出以支持 token by token
streaming=True,
)
10 changes: 10 additions & 0 deletions agentrun/integration/langgraph/agent_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,9 +730,19 @@ def _convert_astream_events_event(
and not self.has_on_chat_model_stream
):
chunk_data = data.get("chunk", {})
messages = []
if isinstance(chunk_data, dict):
messages = chunk_data.get("messages", [])
elif isinstance(chunk_data, list):
for item in chunk_data:
update = getattr(item, "update", None)
if not isinstance(update, dict):
continue
item_messages = update.get("messages", [])
if isinstance(item_messages, list):
messages.extend(item_messages)

if isinstance(messages, list):
for msg in messages:
content = AgentRunConverter._get_message_content(msg)
if content:
Expand Down
1 change: 1 addition & 0 deletions agentrun/model/__model_service_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,5 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
base_url=self.provider_settings.base_url,
model=default_model,
headers=cfg.get_headers(),
provider=self.provider,
)
1 change: 1 addition & 0 deletions agentrun/model/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,4 +404,5 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
base_url=self.provider_settings.base_url,
model=default_model,
headers=cfg.get_headers(),
provider=self.provider,
)
135 changes: 131 additions & 4 deletions agentrun/server/agui_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from dataclasses import dataclass, field
import json
from typing import (
Any,
AsyncIterator,
Expand All @@ -30,6 +31,10 @@
import pydash

from ..utils.helper import merge, MergeOptions
from ..utils.reasoning import (
get_reasoning_content,
is_thinking_enabled_from_env,
)
from .model import (
AgentEvent,
AgentRequest,
Expand Down Expand Up @@ -60,6 +65,14 @@ class TextState:
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class ReasoningState:
started: bool = False
message_started: bool = False
phase_id: str = field(default_factory=lambda: str(uuid.uuid4()))
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class ToolCallState:
name: str = ""
Expand All @@ -72,6 +85,7 @@ class ToolCallState:
@dataclass
class StreamStateMachine:
text: TextState = field(default_factory=TextState)
reasoning: ReasoningState = field(default_factory=ReasoningState)
tool_call_states: Dict[str, ToolCallState] = field(default_factory=dict)
tool_result_chunks: Dict[str, List[str]] = field(default_factory=dict)
run_errored: bool = False
Expand Down Expand Up @@ -121,6 +135,43 @@ def cache_tool_result_chunk(self, tool_id: str, delta: str) -> None:
def pop_tool_result_chunks(self, tool_id: str) -> str:
return "".join(self.tool_result_chunks.pop(tool_id, []))

def ensure_reasoning_started(self) -> Iterator[str]:
if not self.reasoning.started:
yield _encode_reasoning_event(
"REASONING_START",
messageId=self.reasoning.phase_id,
)
self.reasoning.started = True
if not self.reasoning.message_started:
yield _encode_reasoning_event(
"REASONING_MESSAGE_START",
messageId=self.reasoning.message_id,
role="reasoning",
)
self.reasoning.message_started = True

def end_reasoning_if_open(self) -> Iterator[str]:
if self.reasoning.message_started:
yield _encode_reasoning_event(
"REASONING_MESSAGE_END",
messageId=self.reasoning.message_id,
)
self.reasoning.message_started = False
if self.reasoning.started:
yield _encode_reasoning_event(
"REASONING_END",
messageId=self.reasoning.phase_id,
)
self.reasoning = ReasoningState()


def _encode_reasoning_event(event_type: str, **payload: Any) -> str:
return (
"data: "
+ json.dumps({"type": event_type, **payload}, ensure_ascii=False)
+ "\n\n"
)


class AGUIProtocolHandler(BaseProtocolHandler):
"""AG-UI 协议处理器
Expand Down Expand Up @@ -376,6 +427,10 @@ async def _format_stream(
if state.run_errored:
return

# 结束未结束的 reasoning 消息
for sse_data in state.end_reasoning_if_open():
yield sse_data

# 结束所有未结束的工具调用
for sse_data in state.end_all_tools(self._encoder):
yield sse_data
Expand All @@ -399,8 +454,6 @@ def _process_event_with_boundaries(
state: StreamStateMachine,
) -> Iterator[str]:
"""处理事件并注入边界事件"""
import json

from ag_ui.core import CustomEvent as AguiCustomEvent
from ag_ui.core import (
RunErrorEvent,
Expand All @@ -413,6 +466,8 @@ def _process_event_with_boundaries(
ToolCallStartEvent,
)

thinking_enabled = is_thinking_enabled_from_env()

# RAW 事件直接透传
if event.event == EventType.RAW:
raw_data = event.data.get("raw", "")
Expand All @@ -422,9 +477,46 @@ def _process_event_with_boundaries(
yield raw_data
return

if event.event == EventType.REASONING:
if thinking_enabled:
reasoning_content = (
event.data.get("delta")
or get_reasoning_content(event.data)
or ""
)
if reasoning_content:
for sse_data in state.end_text_if_open(self._encoder):
yield sse_data
for sse_data in state.end_all_tools(self._encoder):
yield sse_data
for sse_data in state.ensure_reasoning_started():
yield sse_data
yield _encode_reasoning_event(
"REASONING_MESSAGE_CONTENT",
messageId=state.reasoning.message_id,
delta=reasoning_content,
)
return

# TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START
# AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL
if event.event == EventType.TEXT:
addition = self._strip_reasoning_from_addition(
event.addition, thinking_enabled
)
addition_reasoning = get_reasoning_content(event.addition or {})
if thinking_enabled and addition_reasoning:
for sse_data in state.ensure_reasoning_started():
yield sse_data
yield _encode_reasoning_event(
"REASONING_MESSAGE_CONTENT",
messageId=state.reasoning.message_id,
delta=addition_reasoning,
)

for sse_data in state.end_reasoning_if_open():
yield sse_data

for sse_data in state.end_all_tools(self._encoder):
yield sse_data

Expand All @@ -435,13 +527,13 @@ def _process_event_with_boundaries(
message_id=state.text.message_id,
delta=event.data.get("delta", ""),
)
if event.addition:
if addition:
event_dict = agui_event.model_dump(
by_alias=True, exclude_none=True
)
event_dict = self._apply_addition(
event_dict,
event.addition,
addition,
event.addition_merge_options,
)
json_str = json.dumps(event_dict, ensure_ascii=False)
Expand All @@ -455,6 +547,9 @@ def _process_event_with_boundaries(
tool_id = event.data.get("id", "")
tool_name = event.data.get("name", "")

for sse_data in state.end_reasoning_if_open():
yield sse_data

for sse_data in state.end_text_if_open(self._encoder):
yield sse_data

Expand Down Expand Up @@ -491,6 +586,9 @@ def _process_event_with_boundaries(
tool_name = event.data.get("name", "")
tool_args = event.data.get("args", "")

for sse_data in state.end_reasoning_if_open():
yield sse_data

for sse_data in state.end_text_if_open(self._encoder):
yield sse_data

Expand Down Expand Up @@ -541,6 +639,9 @@ def _process_event_with_boundaries(
timeout = event.data.get("timeout")
schema = event.data.get("schema")

for sse_data in state.end_reasoning_if_open():
yield sse_data

for sse_data in state.end_text_if_open(self._encoder):
yield sse_data

Expand Down Expand Up @@ -601,6 +702,9 @@ def _process_event_with_boundaries(
tool_id = event.data.get("id", "")
tool_name = event.data.get("name", "")

for sse_data in state.end_reasoning_if_open():
yield sse_data

for sse_data in state.end_text_if_open(self._encoder):
yield sse_data

Expand Down Expand Up @@ -767,6 +871,29 @@ def _apply_addition(

return merge(event_data, addition, **(merge_options or {}))

def _strip_reasoning_from_addition(
self,
addition: Optional[Dict[str, Any]],
thinking_enabled: bool,
) -> Optional[Dict[str, Any]]:
if not addition:
return addition

stripped = dict(addition)
stripped.pop("reasoning_content", None)
additional_kwargs = stripped.get("additional_kwargs")
if isinstance(additional_kwargs, dict):
additional_kwargs = dict(additional_kwargs)
additional_kwargs.pop("reasoning_content", None)
if additional_kwargs:
stripped["additional_kwargs"] = additional_kwargs
else:
stripped.pop("additional_kwargs", None)

if not thinking_enabled:
return stripped
return stripped or None

async def _error_stream(self, message: str) -> AsyncIterator[str]:
"""生成错误事件流

Expand Down
Loading
Loading