Files
X-Agents/agent/app/xbot/adapter.py

187 lines
5.5 KiB
Python
Raw Normal View History

"""LLM Adapter - 将现有 LLM 适配到 XBot 接口"""
import json
from dataclasses import dataclass, field
from typing import Any, Optional
from app.agent.llm.factory import LLMFactory
@dataclass
class ToolCallRequest:
    """One tool invocation requested by the LLM.

    Carries the provider-assigned call id, the tool (function) name, and the
    already-decoded keyword arguments for that tool.
    """

    # Provider-assigned identifier for this call.
    id: str
    # Name of the tool/function the model wants to invoke.
    name: str
    # Decoded arguments for the tool.
    arguments: dict[str, Any]

    def to_openai_tool_call(self) -> dict[str, Any]:
        """Render this request as an OpenAI-style ``tool_calls`` entry.

        The arguments are re-encoded to a JSON string with
        ``ensure_ascii=False`` so non-ASCII text is kept verbatim.
        """
        encoded_arguments = json.dumps(self.arguments, ensure_ascii=False)
        function_part = {"name": self.name, "arguments": encoded_arguments}
        return {"id": self.id, "type": "function", "function": function_part}
@dataclass
class LLMResponse:
    """Normalized result of a single LLM provider call.

    Holds the (possibly absent) text content, any tool calls the model asked
    for, the finish reason, token-usage counts, and optional reasoning text.
    """

    # Text reply; may be None (e.g. alongside tool calls).
    content: str | None
    # Tool invocations requested by the model; a fresh list per instance.
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    # Why generation stopped.
    finish_reason: str = "stop"
    # Token counts keyed by counter name.
    usage: dict[str, int] = field(default_factory=dict)
    # Optional reasoning text returned by the provider.
    reasoning_content: str | None = None

    @property
    def has_tool_calls(self) -> bool:
        """Whether the model requested at least one tool call."""
        return bool(self.tool_calls)
class XBotLLMAdapter:
    """
    Adapt an LLM built by ``LLMFactory`` to the XBot LLM-provider interface.

    Wraps the factory-created LLM so it exposes a nanobot-style API:
    - chat_with_retry(messages, tools, model) -> LLMResponse
    - chat(...) -> LLMResponse (alias of chat_with_retry)
    - chat_stream(...) -> async generator of text deltas

    NOTE(review): requests go through ``self._llm.client.chat.completions.create``
    and are awaited, so the factory object is assumed to expose an async
    OpenAI-compatible client as ``.client`` — confirm against LLMFactory.
    """

    def __init__(
        self,
        provider: str,
        model_name: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 4096,
    ):
        """
        Args:
            provider: provider name passed to ``LLMFactory.create``.
            model_name: default model used for every request.
            api_key: optional API key forwarded to the factory.
            base_url: optional endpoint override forwarded to the factory.
            temperature: default sampling temperature for requests.
            max_tokens: default completion-token limit for requests.
        """
        self.provider_name = provider
        self.model = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Build the underlying LLM; its ``.client`` is used for all requests.
        self._llm = LLMFactory.create(provider, model_name, api_key, base_url)
        # When False, tool definitions are not sent with requests.
        self._supports_tools = self._check_tool_support()

    def _check_tool_support(self) -> bool:
        """Return whether the configured model should be offered tools.

        Every model is currently assumed to support tool calling (GPT-4 and
        Claude models do). Extend this if a model that rejects ``tools`` must
        be excluded.
        """
        return True

    def _build_request_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int | None,
        temperature: float | None,
    ) -> dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``.

        ``None`` means "use the adapter default". Explicit falsy values such as
        ``temperature=0.0`` are honoured rather than replaced by the default.
        ``tools`` is only included when non-empty (and tools are supported),
        because OpenAI-compatible endpoints may reject ``"tools": null`` or
        an empty ``"tools": []``.
        """
        kwargs: dict[str, Any] = {
            "model": model or self.model,
            "messages": messages,
            "temperature": self.temperature if temperature is None else temperature,
            "max_tokens": self.max_tokens if max_tokens is None else max_tokens,
        }
        if tools and self._supports_tools:
            kwargs["tools"] = tools
        return kwargs

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> dict[str, Any]:
        """Decode a tool call's ``arguments`` payload into a dict.

        A JSON string is decoded; an already-decoded object is returned as-is.
        ``None`` or a blank string (which zero-argument tool calls may carry)
        yields ``{}`` instead of making ``json.loads`` fail. Malformed JSON
        still raises ``json.JSONDecodeError``.
        """
        if raw is None:
            return {}
        if isinstance(raw, str):
            if not raw.strip():
                return {}
            return json.loads(raw)
        return raw

    @staticmethod
    def _extract_usage(response: Any) -> dict[str, int]:
        """Collect integer token counts from ``response.usage`` ({} if absent)."""
        usage = getattr(response, "usage", None)
        if usage is None:
            return {}
        counts: dict[str, int] = {}
        for key in ("prompt_tokens", "completion_tokens", "total_tokens"):
            value = getattr(usage, key, None)
            if isinstance(value, int):
                counts[key] = value
        return counts

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> LLMResponse:
        """
        Send a single (non-streaming) chat request, with tool-calling support.

        Despite the name, no retry is performed here. Any exception raised by
        the request or while decoding the reply is returned as an
        ``LLMResponse`` with ``finish_reason="error"`` instead of propagating.

        Args:
            messages: chat messages passed through to the client.
            tools: tool definitions; omitted from the request when None/empty.
            model: model override (defaults to ``self.model``).
            max_tokens: token-limit override (``None`` -> ``self.max_tokens``).
            temperature: temperature override (``None`` -> ``self.temperature``;
                ``0.0`` is respected).

        Returns:
            LLMResponse: text content and/or tool calls, plus token usage.
        """
        request = self._build_request_kwargs(messages, tools, model, max_tokens, temperature)
        try:
            response = await self._llm.client.chat.completions.create(**request)
            choice = response.choices[0]
            message = choice.message
            usage = self._extract_usage(response)

            # Tool calls are only surfaced when the caller actually offered tools.
            if message.tool_calls and tools:
                tool_calls = [
                    ToolCallRequest(
                        id=tc.id,
                        name=tc.function.name,
                        arguments=self._parse_tool_arguments(tc.function.arguments),
                    )
                    for tc in message.tool_calls
                ]
                return LLMResponse(
                    content=message.content,
                    tool_calls=tool_calls,
                    finish_reason="tool_calls",
                    usage=usage,
                )

            # Propagate the provider's finish reason (e.g. "length" on truncation).
            return LLMResponse(
                content=message.content or "",
                finish_reason=choice.finish_reason or "stop",
                usage=usage,
            )
        except Exception as e:
            # Deliberate best-effort boundary: report the failure as content.
            return LLMResponse(
                content=f"Error calling LLM: {str(e)}",
                finish_reason="error",
            )

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> LLMResponse:
        """Convenience alias for :meth:`chat_with_retry`.

        ``max_tokens`` and ``temperature`` default to ``None``, meaning the
        values configured on this adapter are used.
        """
        return await self.chat_with_retry(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ):
        """Stream a chat completion, yielding text deltas as they arrive.

        Only ``delta.content`` text is yielded; streamed tool-call deltas are
        not surfaced. ``None`` for ``max_tokens``/``temperature`` means the
        adapter's configured values. On failure a single ``"Error: ..."``
        string is yielded instead of raising, possibly after some content has
        already been streamed.
        """
        request = self._build_request_kwargs(messages, tools, model, max_tokens, temperature)
        try:
            stream = await self._llm.client.chat.completions.create(**request, stream=True)
            async for chunk in stream:
                # Some chunks carry no choices (e.g. usage-only chunks); skip them.
                if not chunk.choices:
                    continue
                text = chunk.choices[0].delta.content
                if text:
                    yield text
        except Exception as e:
            yield f"Error: {str(e)}"