257 lines
7.0 KiB
Python
257 lines
7.0 KiB
Python
|
|
"""XBot Agent - 封装 nanobot 核心能力的 Agent"""
|
|||
|
|
|
|||
|
|
import inspect
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Optional

from .adapter import XBotLLMAdapter, LLMResponse
from .loop import AgentLoop
from .memory import MemoryConsolidator
from .session import SessionManager
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SimpleToolRegistry:
    """Minimal in-memory tool registry.

    Each registered tool is stored as a dict with ``"function"``,
    ``"description"`` and ``"parameters"`` keys. ``get_definitions`` renders
    the registry as a list of ``{"type": "function", "function": {...}}``
    definitions.
    """

    def __init__(self):
        # name -> {"function": callable, "description": str, "parameters": dict | None}
        self._tools: dict[str, Any] = {}

    def register(
        self,
        name: str,
        func: Any,
        description: str = "",
        parameters: Optional[dict] = None,
    ) -> None:
        """Register (or overwrite) a tool.

        Args:
            name: Tool name exposed in the definitions.
            func: Sync or async callable; it is invoked with the tool
                arguments as keyword arguments.
            description: Human-readable description of the tool.
            parameters: JSON-schema object describing the arguments. When
                omitted (or empty), an argument-less object schema is
                advertised.
        """
        self._tools[name] = {
            "function": func,
            "description": description,
            "parameters": parameters,
        }

    @staticmethod
    def _empty_schema() -> dict:
        """Return a fresh JSON-schema object that accepts no arguments."""
        return {"type": "object", "properties": {}, "required": []}

    def get_definitions(self) -> list[dict]:
        """Return the definitions of all registered tools, in registration order."""
        return [
            {
                "type": "function",
                "function": {
                    "name": name,
                    "description": tool.get("description", ""),
                    # Fall back to a fresh empty schema so callers never share
                    # (and accidentally mutate) one default dict.
                    "parameters": tool.get("parameters") or self._empty_schema(),
                },
            }
            for name, tool in self._tools.items()
        ]

    def get(self, name: str) -> Optional[Any]:
        """Return the stored entry dict for ``name``, or ``None`` if unknown."""
        return self._tools.get(name)

    async def execute(self, name: str, arguments: Optional[dict] = None) -> Any:
        """Execute a registered tool and return its result.

        Both plain and ``async def`` callables are supported. A tool may also
        be a plain callable that returns an awaitable. Failures are reported
        as strings rather than raised.

        Args:
            name: Name of the tool to run.
            arguments: Keyword arguments for the tool; ``None`` means no arguments.

        Returns:
            The tool's (awaited) return value, or an error message string.
        """
        tool = self._tools.get(name)
        if not tool:
            return f"Tool {name} not found"

        func = tool.get("function")
        if not func:
            return f"Tool {name} has no function"
        if not callable(func):
            return "Tool function is not callable"

        try:
            result = func(**(arguments or {}))
            # A coroutine *function* has no __await__; only the coroutine it
            # returns does. So test the call result, not the function itself.
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            return f"Tool execution error: {str(e)}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class XBotAgent:
    """XBot agent built on the nanobot core.

    Features:
        - Multi-turn tool-calling conversations
        - Automatic memory consolidation
        - Persistent session history
    """

    def __init__(
        self,
        name: str,
        role_description: str,
        provider: str = "openai",
        model: str = "gpt-4",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        workspace: Optional[Path] = None,
        context_window_tokens: int = 200000,
        max_iterations: int = 50,
        history_max_messages: int = 50,
    ):
        """Initialise the XBot agent.

        Args:
            name: Agent name (used in the system prompt).
            role_description: Agent role description (used in the system prompt).
            provider: LLM provider identifier passed to ``XBotLLMAdapter``.
            model: Model name.
            api_key: API key for the provider.
            base_url: Base URL for the provider.
            workspace: Working directory for sessions and memory. A ``str``
                is also accepted. Defaults to ``$XAGENT_WORKSPACE`` or
                ``./xbot_workspace``.
            context_window_tokens: Context window size given to the memory
                consolidator.
            max_iterations: Maximum agent-loop iterations per turn.
            history_max_messages: Maximum number of history messages fed
                back into each turn.
        """
        self.name = name
        self.role_description = role_description
        self.history_max_messages = history_max_messages

        # Working directory. Normalise through Path() so a plain str path
        # does not fail on .mkdir().
        if workspace is None:
            workspace = os.getenv("XAGENT_WORKSPACE", "./xbot_workspace")
        self.workspace = Path(workspace)
        self.workspace.mkdir(parents=True, exist_ok=True)

        # LLM adapter. Note: the `provider` argument is a provider name;
        # self.provider holds the adapter instance.
        self.provider = XBotLLMAdapter(
            provider=provider,
            model_name=model,
            api_key=api_key,
            base_url=base_url,
        )

        # Tool registry.
        self.tools = SimpleToolRegistry()
        self._register_default_tools()

        # Agent loop.
        self.agent_loop = AgentLoop(
            provider=self.provider,
            model=model,
            tools=self.tools,
            max_iterations=max_iterations,
        )

        # Session manager, rooted at the workspace.
        self.sessions = SessionManager(self.workspace)

        # Memory consolidator.
        self.memory = MemoryConsolidator(
            workspace=self.workspace,
            provider=self.provider,
            model=model,
            sessions=self.sessions,
            context_window_tokens=context_window_tokens,
        )

    def _register_default_tools(self) -> None:
        """Hook for registering built-in tools; currently registers none."""
        pass

    def register_tool(
        self,
        name: str,
        func: Any,
        description: str = "",
        parameters: Optional[dict] = None,
    ) -> None:
        """Register a custom tool.

        Args:
            name: Tool name.
            func: Sync or async callable implementing the tool.
            description: Tool description.
            parameters: JSON-schema object for the tool's arguments. An empty
                object schema is used when omitted.
        """
        schema = parameters or {
            "type": "object",
            "properties": {},
            "required": [],
        }
        self.tools.register(name, func, description)
        # Attach the schema to the registry entry (the dict returned by
        # ``get``) under "parameters", so it is kept with the tool instead of
        # being built and thrown away.
        entry = self.tools.get(name)
        if isinstance(entry, dict):
            entry["parameters"] = schema

    async def run(
        self,
        user_input: str,
        session_id: str = "default",
    ) -> dict[str, Any]:
        """Run one conversation turn.

        Args:
            user_input: The user's message.
            session_id: Session ID; the session is created if it does not exist.

        Returns:
            dict with ``content`` (final reply, or "No response" when empty),
            ``tool_calls`` (tools used by the loop) and ``session_id``.
        """
        session = self.sessions.get_or_create(session_id)

        system_prompt = f"""你是 {self.name}。
{self.role_description}

请根据用户的问题回答,并使用 Markdown 格式输出。"""

        history = session.get_history(max_messages=self.history_max_messages)
        initial_messages = history + [
            {"role": "user", "content": user_input}
        ]

        final_content, tools_used, all_messages = await self.agent_loop.run_loop(
            initial_messages=initial_messages,
            system_prompt=system_prompt,
        )

        # Persist only the messages from this turn (the user input onward).
        # NOTE(review): assumes all_messages begins with initial_messages and
        # does not include the system prompt; confirm against AgentLoop.run_loop.
        session.messages.extend(all_messages[len(history):])
        self.sessions.save(session)

        # Consolidate memory if the session has grown too large.
        await self.memory.maybe_consolidate_by_tokens(session)

        return {
            "content": final_content or "No response",
            "tool_calls": tools_used,
            "session_id": session_id,
        }

    async def run_stream(
        self,
        user_input: str,
        session_id: str = "default",
    ):
        """Run one conversation turn and stream the reply.

        The full agent loop (including tool calls) runs first; only the
        final content is then yielded, one character at a time.

        Args:
            user_input: The user's message.
            session_id: Session ID.

        Yields:
            str: Successive characters of the reply.
        """
        result = await self.run(user_input, session_id)
        for char in result["content"]:
            yield char

    def clear_session(self, session_id: str) -> None:
        """Clear a session's messages, persist the empty session, then drop it from the manager cache."""
        session = self.sessions.get_or_create(session_id)
        session.clear()
        self.sessions.save(session)
        self.sessions.invalidate(session_id)

    def list_sessions(self) -> list[dict]:
        """List all sessions known to the session manager."""
        return self.sessions.list_sessions()
|