187 lines
5.5 KiB
Python
187 lines
5.5 KiB
Python
|
|
"""LLM Adapter - 将现有 LLM 适配到 XBot 接口"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import Any, Optional
|
|||
|
|
|
|||
|
|
from app.agent.llm.factory import LLMFactory
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ToolCallRequest:
    """One tool invocation requested by the LLM.

    Holds the call identifier, the name of the tool to run, and the
    already-decoded argument mapping for that tool.
    """

    id: str
    name: str
    arguments: dict[str, Any]

    def to_openai_tool_call(self) -> dict[str, Any]:
        """Render this request as an OpenAI-style tool-call dictionary.

        Returns:
            A dict with the keys ``id``, ``type`` (always ``"function"``) and
            ``function``; the latter carries the tool name and the arguments
            serialized as a JSON string (non-ASCII characters are kept as-is).
        """
        encoded_arguments = json.dumps(self.arguments, ensure_ascii=False)
        function_payload = {
            "name": self.name,
            "arguments": encoded_arguments,
        }
        return {
            "id": self.id,
            "type": "function",
            "function": function_payload,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class LLMResponse:
    """Result returned by an LLM provider for one chat request."""

    # Text produced by the model; may be None.
    content: str | None
    # Tool invocations requested by the model (empty list when there are none).
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    # Why generation ended; defaults to "stop".
    finish_reason: str = "stop"
    # Integer counters keyed by name; presumably token usage -- confirm with callers.
    usage: dict[str, int] = field(default_factory=dict)
    # Optional reasoning text; None unless a producer sets it.
    reasoning_content: str | None = None

    @property
    def has_tool_calls(self) -> bool:
        """Return True when at least one tool call is attached to this response."""
        requested = len(self.tool_calls)
        return requested > 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
class XBotLLMAdapter:
    """Adapter that exposes an existing LLM through XBot's LLMProvider interface.

    Wraps the LLM object created by ``LLMFactory`` so it offers a
    nanobot-style API:

    - ``chat_with_retry(messages, tools, model) -> LLMResponse``
    - ``chat(...)``: thin wrapper around ``chat_with_retry``
    - ``chat_stream(...)``: async generator yielding text fragments
    """

    def __init__(
        self,
        provider: str,
        model_name: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 4096,
    ):
        """Create the underlying LLM and remember the default sampling settings.

        Args:
            provider: Provider identifier, forwarded to ``LLMFactory.create``.
            model_name: Default model name; also forwarded to the factory.
            api_key: Optional API key forwarded to the factory.
            base_url: Optional base URL forwarded to the factory.
            temperature: Default temperature used by ``chat_with_retry`` when
                the caller does not pass one.
            max_tokens: Default token limit used by ``chat_with_retry`` when
                the caller does not pass one.
        """
        self.provider_name = provider
        self.model = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Build the underlying LLM object.
        self._llm = LLMFactory.create(provider, model_name, api_key, base_url)

        # Record whether tool calling is supported (not read elsewhere in this class).
        self._supports_tools = self._check_tool_support()

    def _check_tool_support(self) -> bool:
        """Report whether the configured model supports tool calling.

        Placeholder check: every model is currently treated as tool-capable,
        so this always returns True.
        """
        return True

    @staticmethod
    def _build_request(
        model: Any,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        temperature: Any,
        max_tokens: Any,
        stream: bool = False,
    ) -> dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``."""
        request: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Only send ``tools`` when some are actually supplied: a null or empty
        # tools value may be rejected by OpenAI-compatible backends.
        if tools:
            request["tools"] = tools
        if stream:
            request["stream"] = True
        return request

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> Any:
        """Decode tool-call arguments delivered as a JSON string.

        Non-string values are returned unchanged. An empty or whitespace-only
        string is treated as "no arguments" and becomes ``{}``. Malformed JSON
        still raises ``json.JSONDecodeError`` so the caller can report it.
        """
        if not isinstance(raw, str):
            return raw
        if not raw.strip():
            return {}
        return json.loads(raw)

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> LLMResponse:
        """Send a chat request (with optional tool calling).

        Despite the name, no retry loop is implemented in this method; a
        single request is issued.

        Args:
            messages: Message list.
            tools: Tool definitions; omitted from the request when None/empty.
            model: Model name; falls back to the instance default.
            max_tokens: Token limit; falls back to the instance default.
            temperature: Temperature; falls back to the instance default only
                when None, so an explicit ``0`` / ``0.0`` is honoured.

        Returns:
            LLMResponse with content and/or tool calls. Any exception raised
            while calling the LLM or decoding its reply is converted into a
            response with ``finish_reason="error"`` and the error text as
            content.
        """
        model = model or self.model
        # 0 is not a usable token budget, so any falsy value falls back.
        max_tokens = max_tokens or self.max_tokens
        # ``is None`` rather than ``or``: 0.0 is a legitimate temperature.
        if temperature is None:
            temperature = self.temperature

        try:
            # Non-streaming call: the complete response arrives in one object.
            request = self._build_request(
                model, messages, tools, temperature, max_tokens
            )
            response = await self._llm.client.chat.completions.create(**request)

            message = response.choices[0].message

            # Tool calls are only surfaced when the caller supplied tools.
            if message.tool_calls and tools:
                tool_calls = [
                    ToolCallRequest(
                        id=tc.id,
                        name=tc.function.name,
                        arguments=self._parse_tool_arguments(tc.function.arguments),
                    )
                    for tc in message.tool_calls
                ]
                return LLMResponse(
                    content=message.content,
                    tool_calls=tool_calls,
                    finish_reason="tool_calls",
                )

            return LLMResponse(
                content=message.content or "",
                finish_reason="stop",
            )

        except Exception as e:
            # Deliberate best-effort: callers receive an error response, not an exception.
            return LLMResponse(
                content=f"Error calling LLM: {str(e)}",
                finish_reason="error",
            )

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Simplified chat method that delegates to ``chat_with_retry``.

        Note that the ``max_tokens`` / ``temperature`` defaults declared here
        are always forwarded, so the instance-level defaults are not used
        when calling through this method.
        """
        return await self.chat_with_retry(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ):
        """Stream a chat completion, yielding non-empty text fragments.

        Chunks that carry no choices are skipped. If any exception occurs, a
        single ``"Error: ..."`` string is yielded and the stream ends.
        """
        model = model or self.model

        try:
            request = self._build_request(
                model, messages, tools, temperature, max_tokens, stream=True
            )
            response = await self._llm.client.chat.completions.create(**request)

            async for chunk in response:
                # A chunk may arrive with an empty ``choices`` list; indexing
                # it would raise IndexError and abort the stream.
                if not chunk.choices:
                    continue
                piece = chunk.choices[0].delta.content
                if piece:
                    yield piece

        except Exception as e:
            yield f"Error: {str(e)}"
|