refactor: 重构 Agent 模块
- 删除旧的 agent 核心文件 - 新增 supervisor, memory, skills 等模块 - 重构 main.py 服务入口 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,28 +0,0 @@
|
||||
# Python Agent Service Dockerfile
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装 Python 依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制应用代码
|
||||
COPY app/ ./app/
|
||||
|
||||
# 创建数据目录
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8081
|
||||
|
||||
# 启动服务
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8081"]
|
||||
@@ -1,192 +1,123 @@
|
||||
"""
|
||||
Agent 核心管理器
|
||||
Agent Core - 单智能体核心
|
||||
"""
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.agent.core.executor import AgentExecutor
|
||||
from app.agent.memory.session import SessionManager
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.security.audit import AuditLogger
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from app.agent.memory.manager import MemoryManager
|
||||
from app.agent.skills.router import SkillRouter
|
||||
from app.agent.skills.executor import SkillExecutor
|
||||
from app.agent.llm.factory import LLMFactory
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""Agent 管理器 - 负责加载和管理所有 Agent"""
|
||||
class AgentConfig(BaseModel):
|
||||
"""智能体配置"""
|
||||
id: int
|
||||
name: str
|
||||
role_description: str
|
||||
model_provider: str = "openai"
|
||||
model_name: str = "gpt-4"
|
||||
skills: List[int] = [] # 技能 ID 列表
|
||||
knowledge_base_ids: List[int] = []
|
||||
timeout: int = 60
|
||||
memory_limit: int = 134217728 # 128MB
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_provider: str = "openai",
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
):
|
||||
self.llm_provider = llm_provider
|
||||
self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
# 初始化组件
|
||||
self.llm_factory = LLMFactory(
|
||||
provider=llm_provider,
|
||||
openai_api_key=self.openai_api_key,
|
||||
anthropic_api_key=self.anthropic_api_key
|
||||
)
|
||||
self.tool_registry = ToolRegistry()
|
||||
self.session_manager = SessionManager()
|
||||
self.audit_logger = AuditLogger()
|
||||
class AgentResponse(BaseModel):
|
||||
"""智能体响应"""
|
||||
content: str
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
tokens_used: int = 0
|
||||
duration_ms: int = 0
|
||||
session_id: Optional[str] = None
|
||||
|
||||
# 已加载的 Agent
|
||||
self.agents: dict[str, dict] = {}
|
||||
self.executors: dict[str, AgentExecutor] = {}
|
||||
|
||||
# 注册默认工具
|
||||
self._register_default_tools()
|
||||
class AgentCore:
|
||||
"""单智能体核心类"""
|
||||
|
||||
def _register_default_tools(self):
|
||||
"""注册默认工具"""
|
||||
from app.agent.tools.impl import search, calculator, time_tool
|
||||
from app.agent.tools.impl import sandbox, database, api_client
|
||||
def __init__(self, config: AgentConfig):
|
||||
self.config = config
|
||||
self.llm = LLMFactory.create(config.model_provider, config.model_name)
|
||||
self.memory = MemoryManager(config.id)
|
||||
self.skill_router = SkillRouter(config.skills)
|
||||
self.skill_executor = SkillExecutor()
|
||||
|
||||
# 安全工具 - Safe 级别
|
||||
self.tool_registry.register(
|
||||
name="search",
|
||||
func=search.search_web,
|
||||
description="Search the web for information",
|
||||
security_level="safe"
|
||||
)
|
||||
async def run(self, user_input: str, user_id: int, session_id: str) -> AgentResponse:
|
||||
"""
|
||||
执行智能体对话
|
||||
|
||||
self.tool_registry.register(
|
||||
name="calculator",
|
||||
func=calculator.calculate,
|
||||
description="Perform mathematical calculations",
|
||||
security_level="safe"
|
||||
)
|
||||
Args:
|
||||
user_input: 用户输入
|
||||
user_id: 用户 ID
|
||||
session_id: 会话 ID
|
||||
|
||||
self.tool_registry.register(
|
||||
name="get_current_time",
|
||||
func=time_tool.get_current_time,
|
||||
description="Get current date and time",
|
||||
security_level="safe"
|
||||
)
|
||||
Returns:
|
||||
AgentResponse: 智能体响应
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 需要审核的工具 - Review 级别
|
||||
self.tool_registry.register(
|
||||
name="execute_code",
|
||||
func=sandbox.sandbox.execute,
|
||||
description="Execute code in sandbox (Python/JavaScript)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "Code to execute"},
|
||||
"language": {"type": "string", "default": "python"},
|
||||
"timeout": {"type": "integer", "default": 30}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
)
|
||||
try:
|
||||
# 1. 加载记忆
|
||||
context = await self.memory.load_context(user_input, user_id, session_id)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="query_database",
|
||||
func=database.query_data,
|
||||
description="Query database (SELECT only)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sql": {"type": "string", "description": "SELECT query"}
|
||||
},
|
||||
"required": ["sql"]
|
||||
}
|
||||
)
|
||||
# 2. 构建 Prompt
|
||||
prompt = self._build_prompt(user_input, context)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="call_api",
|
||||
func=api_client.call_api,
|
||||
description="Call external API (whitelist only)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api_name": {"type": "string"},
|
||||
"endpoint": {"type": "string"},
|
||||
"params": {"type": "object"}
|
||||
},
|
||||
"required": ["api_name"]
|
||||
}
|
||||
)
|
||||
# 3. LLM 决策
|
||||
decision = await self.llm.decide(prompt)
|
||||
|
||||
async def load_agents(self):
|
||||
"""加载 Agent 配置"""
|
||||
# TODO: 从数据库或配置文件加载
|
||||
# 这里先注册一些示例 Agent
|
||||
# 4. 执行技能(如需)
|
||||
if decision.get('needs_skill'):
|
||||
skill_results = await self._execute_skills(decision.get('tool_calls', []))
|
||||
# 5. 基于结果生成回复
|
||||
final_response = await self.llm.generate(prompt, skill_results)
|
||||
else:
|
||||
final_response = decision.get('response', '')
|
||||
|
||||
self.agents["assistant"] = {
|
||||
"name": "General Assistant",
|
||||
"description": "A general purpose assistant",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"tools": ["search", "calculator", "get_current_time"]
|
||||
}
|
||||
# 6. 保存记忆
|
||||
await self.memory.save(user_input, final_response, user_id, session_id)
|
||||
|
||||
self.agents["coder"] = {
|
||||
"name": "Code Assistant",
|
||||
"description": "Helps with coding tasks",
|
||||
"system_prompt": "You are a helpful coding assistant. You can write, explain, and debug code.",
|
||||
"tools": ["search", "calculator"]
|
||||
}
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 为每个 Agent 创建执行器
|
||||
for agent_id, config in self.agents.items():
|
||||
self.executors[agent_id] = AgentExecutor(
|
||||
agent_id=agent_id,
|
||||
llm_factory=self.llm_factory,
|
||||
tool_registry=self.tool_registry,
|
||||
session_manager=self.session_manager,
|
||||
audit_logger=self.audit_logger,
|
||||
config=config
|
||||
return AgentResponse(
|
||||
content=final_response,
|
||||
tool_calls=decision.get('tool_calls', []),
|
||||
duration_ms=duration_ms,
|
||||
session_id=session_id
|
||||
)
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
return AgentResponse(
|
||||
content=f"处理请求时发生错误: {str(e)}",
|
||||
duration_ms=duration_ms,
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
agent_id: str,
|
||||
message: str,
|
||||
session_id: str,
|
||||
context: dict = None
|
||||
) -> dict[str, Any]:
|
||||
"""执行 Agent"""
|
||||
if agent_id not in self.executors:
|
||||
raise ValueError(f"Agent '{agent_id}' not found")
|
||||
def _build_prompt(self, user_input: str, context: dict) -> str:
|
||||
"""构建 Prompt"""
|
||||
system_prompt = f"""你是 {self.config.name}。
|
||||
{self.config.role_description}
|
||||
|
||||
executor = self.executors[agent_id]
|
||||
相关记忆:
|
||||
{context.get('summary', '')}
|
||||
|
||||
# 执行
|
||||
result = await executor.run(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
context=context or {}
|
||||
)
|
||||
知识库信息:
|
||||
{context.get('knowledge', '')}
|
||||
|
||||
return result
|
||||
请根据以上上下文回答用户问题。如果需要使用工具,请明确说明。"""
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""列出所有可用工具"""
|
||||
return self.tool_registry.list_tools()
|
||||
return f"{system_prompt}\n\n用户: {user_input}"
|
||||
|
||||
def list_agents(self) -> list[dict]:
|
||||
"""列出所有 Agent"""
|
||||
return [
|
||||
{
|
||||
"id": agent_id,
|
||||
"name": config["name"],
|
||||
"description": config["description"]
|
||||
}
|
||||
for agent_id, config in self.agents.items()
|
||||
]
|
||||
async def _execute_skills(self, skill_decisions: List[Dict]) -> List[Dict]:
|
||||
"""执行技能"""
|
||||
if not skill_decisions:
|
||||
return []
|
||||
|
||||
def get_agent_info(self, agent_id: str) -> Optional[dict]:
|
||||
"""获取 Agent 信息"""
|
||||
if agent_id not in self.agents:
|
||||
return None
|
||||
return self.agents[agent_id]
|
||||
results = []
|
||||
for decision in skill_decisions:
|
||||
result = await self.skill_executor.execute(
|
||||
skill_id=decision.get('skill_id'),
|
||||
params=decision.get('params', {})
|
||||
)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
Agent 执行器 - 负责执行 Agent 的核心逻辑
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.agent.memory.session import SessionManager
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
class AgentExecutor:
|
||||
"""Agent 执行器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
llm_factory: LLMFactory,
|
||||
tool_registry: ToolRegistry,
|
||||
session_manager: SessionManager,
|
||||
audit_logger: AuditLogger,
|
||||
config: dict
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.llm_factory = llm_factory
|
||||
self.tool_registry = tool_registry
|
||||
self.session_manager = session_manager
|
||||
self.audit_logger = audit_logger
|
||||
self.config = config
|
||||
|
||||
# 获取 LLM
|
||||
self.llm = self.llm_factory.get_llm()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
context: dict
|
||||
) -> dict[str, Any]:
|
||||
"""运行 Agent"""
|
||||
tools_used = []
|
||||
|
||||
# 1. 获取会话历史
|
||||
history = self.session_manager.get_history(session_id)
|
||||
|
||||
# 2. 构建消息列表
|
||||
messages = self._build_messages(message, history)
|
||||
|
||||
# 3. 获取可用工具
|
||||
available_tools = self._get_available_tools()
|
||||
|
||||
# 4. 调用 LLM(带工具)
|
||||
try:
|
||||
response = await self.llm.agenerate(
|
||||
messages=messages,
|
||||
tools=available_tools
|
||||
)
|
||||
|
||||
# 检查是否需要调用工具
|
||||
response_message = response.generations[0][0]
|
||||
|
||||
# 如果有工具调用
|
||||
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
|
||||
for tool_call in response_message.tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_args = tool_call.arguments
|
||||
|
||||
# 记录工具使用
|
||||
tools_used.append(tool_name)
|
||||
|
||||
# 执行工具
|
||||
tool_result = await self._execute_tool(tool_name, tool_args)
|
||||
|
||||
# 添加工具结果到消息
|
||||
messages.append(response_message)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": str(tool_result)
|
||||
})
|
||||
|
||||
# 再次调用 LLM 生成最终响应
|
||||
final_response = await self.llm.agenerate(messages=messages)
|
||||
final_message = final_response.generations[0][0].text
|
||||
|
||||
# 保存到历史
|
||||
self.session_manager.add_message(session_id, "user", message)
|
||||
self.session_manager.add_message(session_id, "assistant", final_message)
|
||||
|
||||
return {
|
||||
"reply": final_message,
|
||||
"tools_used": tools_used,
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
# 没有工具调用,直接返回
|
||||
reply = response_message.text
|
||||
|
||||
# 保存到历史
|
||||
self.session_manager.add_message(session_id, "user", message)
|
||||
self.session_manager.add_message(session_id, "assistant", reply)
|
||||
|
||||
return {
|
||||
"reply": reply,
|
||||
"tools_used": tools_used,
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误
|
||||
self.audit_logger.log(
|
||||
action="agent_error",
|
||||
agent_id=self.agent_id,
|
||||
session_id=session_id,
|
||||
details={"error": str(e)}
|
||||
)
|
||||
raise
|
||||
|
||||
def _build_messages(self, message: str, history: list) -> list:
|
||||
"""构建消息列表"""
|
||||
messages = []
|
||||
|
||||
# 添加系统提示
|
||||
system_prompt = self.config.get("system_prompt", "You are a helpful assistant.")
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# 添加历史
|
||||
for msg in history:
|
||||
messages.append(msg)
|
||||
|
||||
# 添加当前消息
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
return messages
|
||||
|
||||
def _get_available_tools(self) -> list:
|
||||
"""获取可用工具定义"""
|
||||
agent_tools = self.config.get("tools", [])
|
||||
tool_defs = []
|
||||
|
||||
for tool_name in agent_tools:
|
||||
tool_def = self.tool_registry.get_tool_definition(tool_name)
|
||||
if tool_def:
|
||||
tool_defs.append(tool_def)
|
||||
|
||||
return tool_defs
|
||||
|
||||
async def _execute_tool(self, tool_name: str, args: dict) -> Any:
|
||||
"""执行工具"""
|
||||
# 安全检查
|
||||
tool_func, metadata = self.tool_registry.get_tool(tool_name)
|
||||
|
||||
# 如果需要审批,抛出异常
|
||||
if metadata.require_approval:
|
||||
raise PermissionError(
|
||||
f"Tool '{tool_name}' requires approval before execution"
|
||||
)
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
result = tool_func(**args)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
@@ -1,62 +1,125 @@
|
||||
"""
|
||||
会话管理器 - 管理 Agent 的会话历史
|
||||
Session Memory - 会话级记忆(Redis 存储)
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
class SessionMemory:
|
||||
"""会话级记忆,Redis 存储"""
|
||||
|
||||
def __init__(self, max_history: int = 10):
|
||||
def __init__(self, agent_id: int, redis_client=None):
|
||||
"""
|
||||
初始化会话管理器
|
||||
初始化会话记忆
|
||||
|
||||
Args:
|
||||
max_history: 每个会话保留的最大历史消息数
|
||||
agent_id: 智能体 ID
|
||||
redis_client: Redis 客户端(可选)
|
||||
"""
|
||||
self.max_history = max_history
|
||||
self.sessions: dict[str, list[dict]] = defaultdict(list)
|
||||
self.metadata: dict[str, dict] = {}
|
||||
self.agent_id = agent_id
|
||||
self.redis = redis_client
|
||||
self.ttl = 3600 * 24 # 24 小时
|
||||
self.summary_threshold = 10 # 多少条消息后生成摘要
|
||||
|
||||
def add_message(self, session_id: str, role: str, content: str):
|
||||
"""添加消息到会话"""
|
||||
self.sessions[session_id].append({
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
def _key(self, user_id: int, session_id: str) -> str:
|
||||
"""生成 Redis Key"""
|
||||
return f"agent:memory:session:{self.agent_id}:{user_id}:{session_id}"
|
||||
|
||||
# 限制历史长度
|
||||
if len(self.sessions[session_id]) > self.max_history:
|
||||
self.sessions[session_id] = self.sessions[session_id][-self.max_history:]
|
||||
async def add(self, user_input: str, response: str, user_id: int, session_id: str):
|
||||
"""
|
||||
添加对话到会话记忆
|
||||
|
||||
def get_history(self, session_id: str) -> list[dict]:
|
||||
"""获取会话历史"""
|
||||
return self.sessions.get(session_id, [])
|
||||
Args:
|
||||
user_input: 用户输入
|
||||
response: 智能体回复
|
||||
user_id: 用户 ID
|
||||
session_id: 会话 ID
|
||||
"""
|
||||
if not self.redis:
|
||||
# 如果没有 Redis,使用内存模拟
|
||||
return self._add_memory(user_input, response, user_id, session_id)
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""清除会话"""
|
||||
if session_id in self.sessions:
|
||||
del self.sessions[session_id]
|
||||
if session_id in self.metadata:
|
||||
del self.metadata[session_id]
|
||||
key = self._key(user_id, session_id)
|
||||
|
||||
def set_metadata(self, session_id: str, key: str, value: Any):
|
||||
"""设置会话元数据"""
|
||||
if session_id not in self.metadata:
|
||||
self.metadata[session_id] = {}
|
||||
self.metadata[session_id][key] = value
|
||||
# 获取现有数据
|
||||
data = await self.redis.get(key)
|
||||
messages = json.loads(data) if data else {"messages": [], "summary": ""}
|
||||
|
||||
def get_metadata(self, session_id: str, key: str, default: Any = None) -> Any:
|
||||
"""获取会话元数据"""
|
||||
return self.metadata.get(session_id, {}).get(key, default)
|
||||
# 添加新消息
|
||||
messages["messages"].append({"role": "user", "content": user_input})
|
||||
messages["messages"].append({"role": "assistant", "content": response})
|
||||
|
||||
def list_sessions(self) -> list[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self.sessions.keys())
|
||||
# 定期生成摘要
|
||||
if len(messages["messages"]) >= self.summary_threshold:
|
||||
messages["summary"] = await self._generate_summary(messages["messages"])
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取会话数量"""
|
||||
return len(self.sessions)
|
||||
# 保持消息数量
|
||||
if len(messages["messages"]) > 50:
|
||||
messages["messages"] = messages["messages"][-50:]
|
||||
|
||||
await self.redis.setex(key, self.ttl, json.dumps(messages))
|
||||
|
||||
async def get_summary(self, user_id: int, session_id: str) -> str:
|
||||
"""
|
||||
获取会话摘要
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
str: 会话摘要
|
||||
"""
|
||||
if not self.redis:
|
||||
return self._get_memory_summary(user_id, session_id)
|
||||
|
||||
key = self._key(user_id, session_id)
|
||||
data = await self.redis.get(key)
|
||||
|
||||
if data:
|
||||
messages = json.loads(data)
|
||||
return messages.get("summary", "")
|
||||
return ""
|
||||
|
||||
async def _generate_summary(self, messages: list) -> str:
|
||||
"""
|
||||
生成摘要(简化版)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
str: 摘要
|
||||
"""
|
||||
# 简化:取最后几条消息的要点
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
recent = messages[-6:] # 最近 3 轮
|
||||
summary = f"最近对话包含 {len(messages)//2} 轮交互"
|
||||
|
||||
# TODO: 后续可以使用 LLM 生成更好的摘要
|
||||
return summary
|
||||
|
||||
# === 内存模拟(无 Redis 时使用)===
|
||||
_memory_store = {}
|
||||
|
||||
def _add_memory(self, user_input: str, response: str, user_id: int, session_id: str):
|
||||
"""内存模拟 - 添加"""
|
||||
key = f"{self.agent_id}:{user_id}:{session_id}"
|
||||
if key not in self._memory_store:
|
||||
self._memory_store[key] = {"messages": [], "summary": ""}
|
||||
|
||||
messages = self._memory_store[key]["messages"]
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
if len(messages) >= self.summary_threshold:
|
||||
self._memory_store[key]["summary"] = self._generate_summary(messages)
|
||||
|
||||
def _get_memory_summary(self, user_id: int, session_id: str) -> str:
|
||||
"""内存模拟 - 获取摘要"""
|
||||
key = f"{self.agent_id}:{user_id}:{session_id}"
|
||||
if key in self._memory_store:
|
||||
return self._memory_store[key].get("summary", "")
|
||||
return ""
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
"""
|
||||
多智能体系统
|
||||
"""
|
||||
from .types import AgentState, TaskItem, TaskStatus, AgentType, SupervisorDecision, ReviewResult
|
||||
from .prompts import SUPERVISOR_SYSTEM_PROMPT, REVIEW_SYSTEM_PROMPT, RESEARCH_SYSTEM_PROMPT, CODER_SYSTEM_PROMPT, AGGREGATOR_SYSTEM_PROMPT
|
||||
from .supervisor import SupervisorAgent
|
||||
from .graph import create_multi_agent_graph
|
||||
|
||||
__all__ = [
|
||||
"AgentState",
|
||||
"TaskItem",
|
||||
"TaskStatus",
|
||||
"AgentType",
|
||||
"SupervisorDecision",
|
||||
"ReviewResult",
|
||||
"SUPERVISOR_SYSTEM_PROMPT",
|
||||
"REVIEW_SYSTEM_PROMPT",
|
||||
"RESEARCH_SYSTEM_PROMPT",
|
||||
"CODER_SYSTEM_PROMPT",
|
||||
"AGGREGATOR_SYSTEM_PROMPT",
|
||||
"SupervisorAgent",
|
||||
"create_multi_agent_graph",
|
||||
]
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
LangGraph 流程编排
|
||||
"""
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
|
||||
from .types import AgentState, AgentType
|
||||
from .supervisor import SupervisorAgent, ResultAggregator
|
||||
from .workers.research import ResearchWorker
|
||||
from .workers.coder import CoderWorker
|
||||
from .workers.review import ReviewWorker
|
||||
|
||||
|
||||
def create_multi_agent_graph(
|
||||
llm,
|
||||
tool_registry=None,
|
||||
max_iterations: int = 3,
|
||||
max_tasks: int = 10
|
||||
) -> CompiledGraph:
|
||||
"""创建多 Agent 流程图
|
||||
|
||||
Args:
|
||||
llm: 语言模型实例
|
||||
tool_registry: 工具注册表
|
||||
max_iterations: 最大迭代次数
|
||||
max_tasks: 最大任务数
|
||||
|
||||
Returns:
|
||||
CompiledGraph: 编译后的 LangGraph
|
||||
"""
|
||||
|
||||
# 初始化组件
|
||||
supervisor = SupervisorAgent(llm, max_iterations=max_iterations, max_tasks=max_tasks)
|
||||
research_worker = ResearchWorker(llm, tool_registry)
|
||||
coder_worker = CoderWorker(llm, tool_registry)
|
||||
review_worker = ReviewWorker(llm, tool_registry)
|
||||
aggregator = ResultAggregator(llm)
|
||||
|
||||
# 创建图
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
# 添加节点
|
||||
graph.add_node("supervisor", supervisor.create_node())
|
||||
graph.add_node(AgentType.RESEARCH, research_worker.create_node())
|
||||
graph.add_node(AgentType.CODER, coder_worker.create_node())
|
||||
graph.add_node(AgentType.REVIEW, review_worker.create_node())
|
||||
graph.add_node("aggregator", aggregator.create_node())
|
||||
|
||||
# 设置入口点
|
||||
graph.set_entry_point("supervisor")
|
||||
|
||||
# 定义条件边函数
|
||||
def should_continue(state: AgentState) -> str:
|
||||
"""判断是否继续执行"""
|
||||
|
||||
# 获取下一步节点
|
||||
next_node = state.get("next_node", "aggregator")
|
||||
|
||||
# 如果是结束节点
|
||||
if next_node in ["__end__", "aggregator"]:
|
||||
return "aggregator"
|
||||
|
||||
# 如果是 Worker 节点
|
||||
if next_node in [AgentType.RESEARCH, AgentType.CODER, AgentType.REVIEW]:
|
||||
return next_node
|
||||
|
||||
# 如果是 supervisor
|
||||
if next_node == "supervisor":
|
||||
# 检查迭代次数
|
||||
iteration = state.get("iteration", 0)
|
||||
if iteration >= max_iterations:
|
||||
return "aggregator"
|
||||
return "supervisor"
|
||||
|
||||
# 默认进入汇总
|
||||
return "aggregator"
|
||||
|
||||
# 添加条件边:从 supervisor 出来
|
||||
graph.add_conditional_edges(
|
||||
"supervisor",
|
||||
should_continue,
|
||||
{
|
||||
"supervisor": "supervisor",
|
||||
AgentType.RESEARCH: AgentType.RESEARCH,
|
||||
AgentType.CODER: AgentType.CODER,
|
||||
AgentType.REVIEW: AgentType.REVIEW,
|
||||
"aggregator": "aggregator"
|
||||
}
|
||||
)
|
||||
|
||||
# 添加边:Worker -> Review
|
||||
graph.add_edge(AgentType.RESEARCH, AgentType.REVIEW)
|
||||
graph.add_edge(AgentType.CODER, AgentType.REVIEW)
|
||||
|
||||
# 添加条件边:从 Review 出来
|
||||
graph.add_conditional_edges(
|
||||
AgentType.REVIEW,
|
||||
should_continue,
|
||||
{
|
||||
"supervisor": "supervisor",
|
||||
"aggregator": "aggregator"
|
||||
}
|
||||
)
|
||||
|
||||
# 添加边:aggregator -> END
|
||||
graph.add_edge("aggregator", END)
|
||||
|
||||
# 编译图
|
||||
return graph.compile()
|
||||
|
||||
|
||||
def create_simple_graph(llm, tool_registry=None) -> CompiledGraph:
|
||||
"""创建简单的单 Agent 图(不经过 Supervisor)"""
|
||||
|
||||
# 创建图
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
# 直接使用 Coder Worker
|
||||
coder_worker = CoderWorker(llm, tool_registry)
|
||||
|
||||
# 添加节点
|
||||
graph.add_node("coder", coder_worker.create_node())
|
||||
|
||||
# 设置入口
|
||||
graph.set_entry_point("coder")
|
||||
|
||||
# 添加边
|
||||
graph.add_edge("coder", END)
|
||||
|
||||
return graph.compile()
|
||||
@@ -1,223 +0,0 @@
|
||||
"""
|
||||
多智能体系统 - 与现有系统集成
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.agent.memory.session import SessionManager
|
||||
|
||||
from .types import create_initial_state
|
||||
from .graph import create_multi_agent_graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiAgentSystem:
|
||||
"""多智能体系统 - 集成现有组件"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_provider: str = "openai",
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
max_iterations: int = 3,
|
||||
max_tasks: int = 10
|
||||
):
|
||||
"""
|
||||
初始化多智能体系统
|
||||
|
||||
Args:
|
||||
llm_provider: LLM 提供商
|
||||
openai_api_key: OpenAI API Key
|
||||
anthropic_api_key: Anthropic API Key
|
||||
max_iterations: 最大迭代次数
|
||||
max_tasks: 最大任务数
|
||||
"""
|
||||
# 初始化 LLM Factory
|
||||
self.llm_factory = LLMFactory(
|
||||
provider=llm_provider,
|
||||
openai_api_key=openai_api_key,
|
||||
anthropic_api_key=anthropic_api_key
|
||||
)
|
||||
|
||||
# 初始化 Tool Registry
|
||||
self.tool_registry = ToolRegistry()
|
||||
self._register_default_tools()
|
||||
|
||||
# 初始化 Session Manager
|
||||
self.session_manager = SessionManager()
|
||||
|
||||
# 配置
|
||||
self.max_iterations = max_iterations
|
||||
self.max_tasks = max_tasks
|
||||
|
||||
# 图实例(延迟初始化)
|
||||
self._graph = None
|
||||
|
||||
def _register_default_tools(self):
|
||||
"""注册默认工具"""
|
||||
try:
|
||||
from app.agent.tools.impl import search, calculator, time_tool
|
||||
|
||||
# 安全工具
|
||||
self.tool_registry.register(
|
||||
name="search",
|
||||
func=search.search_web,
|
||||
description="Search the web for information",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="calculator",
|
||||
func=calculator.calculate,
|
||||
description="Perform mathematical calculations",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="get_current_time",
|
||||
func=time_tool.get_current_time,
|
||||
description="Get current date and time",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
# 执行代码工具
|
||||
try:
|
||||
from app.agent.tools.impl import sandbox
|
||||
self.tool_registry.register(
|
||||
name="execute_code",
|
||||
func=sandbox.sandbox.execute,
|
||||
description="Execute code in sandbox",
|
||||
security_level="review",
|
||||
require_approval=True
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import default tools: {e}")
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
"""获取或创建 LangGraph"""
|
||||
if self._graph is None:
|
||||
llm = self.llm_factory.get_llm()
|
||||
self._graph = create_multi_agent_graph(
|
||||
llm=llm,
|
||||
tool_registry=self.tool_registry,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tasks=self.max_tasks
|
||||
)
|
||||
return self._graph
|
||||
|
||||
async def execute(self, task: str, session_id: str = None) -> dict:
|
||||
"""
|
||||
执行多 Agent 任务
|
||||
|
||||
Args:
|
||||
task: 任务描述
|
||||
session_id: 会话 ID(可选)
|
||||
|
||||
Returns:
|
||||
dict: 执行结果
|
||||
"""
|
||||
# 创建初始状态
|
||||
initial_state = create_initial_state(task, session_id)
|
||||
|
||||
try:
|
||||
# 执行图
|
||||
result = await self.graph.ainvoke(initial_state)
|
||||
|
||||
# 保存到 session
|
||||
if session_id:
|
||||
self.session_manager.add_message(session_id, "user", task)
|
||||
self.session_manager.add_message(
|
||||
session_id,
|
||||
"assistant",
|
||||
result.get("final_output", "")
|
||||
)
|
||||
|
||||
return {
|
||||
"success": result.get("status") != "failed",
|
||||
"output": result.get("final_output", ""),
|
||||
"status": result.get("status", "unknown"),
|
||||
"task_plan": result.get("task_plan", []),
|
||||
"results": result.get("results", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Multi-agent execution failed: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"output": f"执行失败: {str(e)}",
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def execute_simple(self, task: str, session_id: str = None) -> dict:
|
||||
"""
|
||||
执行简单任务(不使用 Supervisor)
|
||||
|
||||
Args:
|
||||
task: 任务描述
|
||||
session_id: 会话 ID(可选)
|
||||
|
||||
Returns:
|
||||
dict: 执行结果
|
||||
"""
|
||||
from .graph import create_simple_graph
|
||||
|
||||
# 创建简单图
|
||||
llm = self.llm_factory.get_llm()
|
||||
simple_graph = create_simple_graph(llm, self.tool_registry)
|
||||
|
||||
# 创建初始状态
|
||||
initial_state = create_initial_state(task, session_id)
|
||||
|
||||
try:
|
||||
# 执行图
|
||||
result = await simple_graph.ainvoke(initial_state)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"output": result.get("final_output", ""),
|
||||
"status": result.get("status", "completed")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Simple execution failed: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"output": f"执行失败: {str(e)}",
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""列出所有可用工具"""
|
||||
return self.tool_registry.list_tools()
|
||||
|
||||
|
||||
# 全局实例
|
||||
_global_system: Optional[MultiAgentSystem] = None
|
||||
|
||||
|
||||
def get_multi_agent_system(
|
||||
llm_provider: str = "openai",
|
||||
openai_api_key: str = None,
|
||||
anthropic_api_key: str = None,
|
||||
**kwargs
|
||||
) -> MultiAgentSystem:
|
||||
"""获取全局多智能体系统实例"""
|
||||
global _global_system
|
||||
|
||||
if _global_system is None:
|
||||
_global_system = MultiAgentSystem(
|
||||
llm_provider=llm_provider,
|
||||
openai_api_key=openai_api_key,
|
||||
anthropic_api_key=anthropic_api_key,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return _global_system
|
||||
@@ -1,117 +0,0 @@
|
||||
"""
|
||||
迭代控制器
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class IterationController:
|
||||
"""迭代控制器 - 管理任务执行的迭代"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_iterations: int = 3,
|
||||
max_retries_per_task: int = 2
|
||||
):
|
||||
"""
|
||||
初始化迭代控制器
|
||||
|
||||
Args:
|
||||
max_iterations: 全局最大迭代次数
|
||||
max_retries_per_task: 每个任务的最大重试次数
|
||||
"""
|
||||
self.max_iterations = max_iterations
|
||||
self.max_retries_per_task = max_retries_per_task
|
||||
|
||||
def should_continue(
|
||||
self,
|
||||
iteration: int,
|
||||
task_status: str,
|
||||
review_result: Optional[dict] = None
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
判断是否继续迭代
|
||||
|
||||
Args:
|
||||
iteration: 当前迭代次数
|
||||
task_status: 任务状态
|
||||
review_result: 评审结果(可选)
|
||||
|
||||
Returns:
|
||||
(是否继续, 原因)
|
||||
"""
|
||||
# 超过最大迭代次数
|
||||
if iteration >= self.max_iterations:
|
||||
return False, "max_iterations_reached"
|
||||
|
||||
# 任务成功完成
|
||||
if task_status == "completed":
|
||||
if review_result and review_result.get("passed"):
|
||||
return False, "task_completed"
|
||||
elif review_result is None:
|
||||
return False, "task_completed"
|
||||
|
||||
# 任务失败且不可重试
|
||||
if task_status == "failed":
|
||||
if review_result and not review_result.get("retryable", True):
|
||||
return False, "task_failed_non_retryable"
|
||||
|
||||
# 检查重试次数
|
||||
retry_count = review_result.get("retry_count", 0) if review_result else 0
|
||||
if retry_count >= self.max_retries_per_task:
|
||||
return False, "max_retries_reached"
|
||||
|
||||
# 需要重试
|
||||
if review_result:
|
||||
issues = review_result.get("issues", [])
|
||||
if issues and not review_result.get("passed", True):
|
||||
return True, "needs_retry"
|
||||
|
||||
return True, "continue"
|
||||
|
||||
def get_next_action(
|
||||
self,
|
||||
review_result: Optional[dict],
|
||||
current_worker: str
|
||||
) -> str:
|
||||
"""
|
||||
确定下一步动作
|
||||
|
||||
Args:
|
||||
review_result: 评审结果
|
||||
current_worker: 当前执行的 Worker
|
||||
|
||||
Returns:
|
||||
下一个节点名称
|
||||
"""
|
||||
if review_result is None:
|
||||
return "supervisor"
|
||||
|
||||
# 根据评审结果决定下一步
|
||||
if review_result.get("passed"):
|
||||
return "supervisor"
|
||||
|
||||
# 根据问题类型决定下一步
|
||||
issues = review_result.get("issues", [])
|
||||
high_severity = any(i.get("severity") == "high" for i in issues)
|
||||
|
||||
if high_severity:
|
||||
# 严重问题,重新执行相同任务
|
||||
return current_worker
|
||||
else:
|
||||
# 轻微问题,返回 Supervisor
|
||||
return "supervisor"
|
||||
|
||||
def calculate_backoff_delay(self, retry_count: int) -> float:
|
||||
"""
|
||||
计算退避延迟(指数退避)
|
||||
|
||||
Args:
|
||||
retry_count: 重试次数
|
||||
|
||||
Returns:
|
||||
延迟时间(秒)
|
||||
"""
|
||||
base_delay = 1.0
|
||||
max_delay = 30.0
|
||||
delay = min(base_delay * (2 ** retry_count), max_delay)
|
||||
return delay
|
||||
@@ -1,170 +0,0 @@
|
||||
"""
|
||||
多智能体系统 Prompt 模板
|
||||
"""
|
||||
|
||||
# Supervisor System Prompt
|
||||
SUPERVISOR_SYSTEM_PROMPT = """你是一个任务规划专家(Supervisor)。你的职责是将复杂任务分解为可执行的子任务,并分配给合适的执行 Agent。
|
||||
|
||||
## 可用的 Worker Agent
|
||||
- **research**: 信息搜索和调研
|
||||
- **coder**: 代码编写、修改和调试
|
||||
- **review**: 结果检查、质量评审
|
||||
|
||||
## 任务
|
||||
{task}
|
||||
|
||||
## 当前进度
|
||||
{progress}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 请按以下步骤执行
|
||||
|
||||
### 步骤 1: 任务分析
|
||||
分析任务的性质,确定需要哪些步骤来完成。
|
||||
|
||||
### 步骤 2: 任务分解
|
||||
将任务分解为独立的子任务。每个子任务应该:
|
||||
- 描述清晰
|
||||
- 可以由单个 Agent 完成
|
||||
- 有明确的完成标准
|
||||
|
||||
### 步骤 3: 分配 Agent
|
||||
为每个子任务选择最合适的执行 Agent。
|
||||
|
||||
### 步骤 4: 确定执行顺序
|
||||
如果有依赖关系,确定正确的执行顺序。
|
||||
|
||||
## 输出格式
|
||||
请以 JSON 格式输出你的决策,包含以下字段:
|
||||
- analysis: 任务分析
|
||||
- task_plan: 任务计划数组,每个元素包含 id, description, assigned_agent
|
||||
- need_aggregation: 是否需要汇总结果
|
||||
- next_worker: 下一个执行的 Worker 名称 (research/coder/review)
|
||||
|
||||
## 注意
|
||||
- 如果任务很简单,可以只分配给一个 Agent
|
||||
- 如果任务需要迭代优化,确保有 review 环节
|
||||
- 考虑任务之间的依赖关系
|
||||
- 使用 "research"/"coder"/"review" 作为 assigned_agent 的值
|
||||
"""
|
||||
|
||||
# Review Worker System Prompt
|
||||
REVIEW_SYSTEM_PROMPT = """你是一个代码和结果评审专家(Reviewer)。你的职责是检查任务执行结果是否符合要求。
|
||||
|
||||
## 原始任务
|
||||
{original_task}
|
||||
|
||||
## 当前任务描述
|
||||
{task_description}
|
||||
|
||||
## 执行结果
|
||||
{execution_result}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 检查标准
|
||||
1. 结果是否完整解决了原始任务?
|
||||
2. 输出格式是否正确?
|
||||
3. 是否存在明显的错误或遗漏?
|
||||
4. 代码是否有潜在问题?
|
||||
5. 是否有安全漏洞或风险?
|
||||
|
||||
## 输出格式
|
||||
请以 JSON 格式输出评审结果:
|
||||
- passed: true/false,是否通过
|
||||
- issues: 问题数组,每个包含 severity(high/medium/low) 和 description
|
||||
- suggestions: 改进建议数组
|
||||
- retryable: true/false,是否可以重试
|
||||
|
||||
## 注意
|
||||
- 如果只有轻微问题,passed 可以为 true
|
||||
- 如果有严重问题,passed 应为 false
|
||||
- 判断是否需要重试,而不是立即失败
|
||||
"""
|
||||
|
||||
# Research Worker System Prompt
|
||||
RESEARCH_SYSTEM_PROMPT = """你是一个信息搜索和调研专家(Researcher)。你的职责是根据任务要求搜集和整理信息。
|
||||
|
||||
## 任务
|
||||
{task}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 请执行以下步骤
|
||||
|
||||
### 1. 理解任务
|
||||
明确需要搜集什么信息,信息的用途是什么。
|
||||
|
||||
### 2. 搜索信息
|
||||
使用可用工具搜索相关信息。
|
||||
|
||||
### 3. 整理结果
|
||||
将搜索结果整理成结构化的信息。
|
||||
|
||||
## 输出要求
|
||||
- 提供清晰、结构化的信息整理
|
||||
- 标注信息来源
|
||||
- 如果无法完成任务,说明原因
|
||||
"""
|
||||
|
||||
# Coder Worker System Prompt
|
||||
CODER_SYSTEM_PROMPT = """你是一个代码编写专家(Coder)。你的职责是根据任务要求编写和修改代码。
|
||||
|
||||
## 任务
|
||||
{task}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 请执行以下步骤
|
||||
|
||||
### 1. 理解需求
|
||||
明确需要编写什么代码,代码的用途和约束。
|
||||
|
||||
### 2. 编写代码
|
||||
使用合适的编程语言和框架编写代码。
|
||||
|
||||
### 3. 代码检查
|
||||
确保代码语法正确,逻辑合理。
|
||||
|
||||
## 输出要求
|
||||
- 提供完整的、可运行的代码
|
||||
- 包含必要的注释说明
|
||||
- 如果需要执行代码,使用代码执行工具
|
||||
"""
|
||||
|
||||
# Aggregator System Prompt
|
||||
AGGREGATOR_SYSTEM_PROMPT = """你是一个结果汇总专家(Aggregator)。你的职责是将多个子任务的结果汇总成最终输出。
|
||||
|
||||
## 原始任务
|
||||
{original_task}
|
||||
|
||||
## 任务计划
|
||||
{task_plan}
|
||||
|
||||
## 执行结果
|
||||
{results}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 请执行以下步骤
|
||||
|
||||
### 1. 分析结果
|
||||
分析每个子任务的执行结果。
|
||||
|
||||
### 2. 识别关键信息
|
||||
从结果中提取关键信息。
|
||||
|
||||
### 3. 汇总输出
|
||||
将所有结果整合成一个连贯的最终输出。
|
||||
|
||||
## 输出要求
|
||||
- 提供清晰、完整的最终结果
|
||||
- 标注每个部分的来源
|
||||
- 确保结果解决了原始任务
|
||||
"""
|
||||
@@ -1,262 +0,0 @@
|
||||
"""
|
||||
Supervisor Agent - 负责任务规划和分发
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from typing import Optional
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from .types import AgentState, TaskItem, AgentType, SupervisorDecision
|
||||
from .prompts import SUPERVISOR_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class SupervisorAgent:
|
||||
"""Supervisor Agent - 负责任务规划和分发"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
max_iterations: int = 3,
|
||||
max_tasks: int = 10
|
||||
):
|
||||
self.llm = llm
|
||||
self.max_iterations = max_iterations
|
||||
self.max_tasks = max_tasks
|
||||
|
||||
def create_node(self):
|
||||
"""创建 LangGraph 节点"""
|
||||
return self._supervisor_node
|
||||
|
||||
async def _supervisor_node(self, state: AgentState) -> dict:
|
||||
"""Supervisor 节点逻辑"""
|
||||
|
||||
# 首次调用:分析任务并生成计划
|
||||
if not state.get("task_plan"):
|
||||
decision = await self._plan_tasks(
|
||||
task=state["original_task"],
|
||||
progress="这是任务的开始",
|
||||
context=state.get("shared_context", {})
|
||||
)
|
||||
|
||||
return {
|
||||
"task_plan": decision.task_plan,
|
||||
"next_node": decision.next_worker,
|
||||
"current_task_index": 0,
|
||||
"shared_context": {
|
||||
**state.get("shared_context", {}),
|
||||
"task_analysis": decision.analysis
|
||||
}
|
||||
}
|
||||
|
||||
# 非首次调用:检查任务状态,决定下一步
|
||||
current_task_index = state.get("current_task_index", 0)
|
||||
task_plan = state.get("task_plan", [])
|
||||
|
||||
# 获取当前任务
|
||||
if current_task_index >= len(task_plan):
|
||||
# 所有任务完成,进入汇总
|
||||
return {
|
||||
"next_node": "aggregator",
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
current_task = task_plan[current_task_index]
|
||||
|
||||
# 检查当前任务状态
|
||||
if current_task.status == "completed":
|
||||
# 当前任务完成,检查是否还有更多任务
|
||||
if current_task_index + 1 < len(task_plan):
|
||||
next_index = current_task_index + 1
|
||||
next_task = task_plan[next_index]
|
||||
return {
|
||||
"current_task_index": next_index,
|
||||
"next_node": next_task.assigned_agent,
|
||||
"iteration": state.get("iteration", 0),
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
else:
|
||||
# 所有任务完成,进入汇总
|
||||
return {
|
||||
"next_node": "aggregator",
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
elif current_task.status == "failed":
|
||||
# 任务失败,检查是否超过最大重试
|
||||
if current_task.retry_count >= self.max_iterations:
|
||||
# 超过最大重试,进入汇总(标记失败)
|
||||
return {
|
||||
"next_node": "aggregator",
|
||||
"status": "failed",
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
else:
|
||||
# 重试当前任务
|
||||
return {
|
||||
"next_node": current_task.assigned_agent,
|
||||
"iteration": state.get("iteration", 0) + 1,
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
elif current_task.status == "needs_retry":
|
||||
# 需要重试(来自 review)
|
||||
return {
|
||||
"next_node": current_task.assigned_agent,
|
||||
"iteration": state.get("iteration", 0) + 1,
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
# 默认继续执行
|
||||
return {
|
||||
"next_node": state.get("next_node", "aggregator"),
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
async def _plan_tasks(self, task: str, progress: str, context: dict) -> SupervisorDecision:
|
||||
"""调用 LLM 生成任务计划"""
|
||||
|
||||
# 格式化 prompt
|
||||
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else "无"
|
||||
prompt = SUPERVISOR_SYSTEM_PROMPT.format(
|
||||
task=task,
|
||||
progress=progress,
|
||||
context=context_str
|
||||
)
|
||||
|
||||
# 调用 LLM
|
||||
response = await self.llm.ainvoke([
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(content="请分析任务并制定执行计划。")
|
||||
])
|
||||
|
||||
# 解析 LLM 输出
|
||||
decision = self._parse_response(response.content, task)
|
||||
|
||||
return decision
|
||||
|
||||
def _parse_response(self, response: str, original_task: str) -> SupervisorDecision:
|
||||
"""解析 LLM 响应为结构化决策"""
|
||||
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("No JSON found")
|
||||
|
||||
# 解析任务计划
|
||||
task_plan = []
|
||||
for i, item in enumerate(data.get("task_plan", [])):
|
||||
task = TaskItem(
|
||||
id=item.get("id", f"task_{i+1}"),
|
||||
description=item.get("description", ""),
|
||||
assigned_agent=AgentType(item.get("assigned_agent", "coder")),
|
||||
status="pending"
|
||||
)
|
||||
task_plan.append(task)
|
||||
|
||||
# 确定下一个 Worker
|
||||
next_worker = data.get("next_worker", "research")
|
||||
if isinstance(next_worker, dict):
|
||||
next_worker = next_worker.get("assigned_agent", "research")
|
||||
|
||||
return SupervisorDecision(
|
||||
analysis=data.get("analysis", "任务分析"),
|
||||
task_plan=task_plan,
|
||||
need_aggregation=data.get("need_aggregation", True),
|
||||
next_worker=AgentType(next_worker)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 解析失败,创建默认计划
|
||||
return self._create_default_plan(original_task)
|
||||
|
||||
def _create_default_plan(self, task: str) -> SupervisorDecision:
|
||||
"""创建默认任务计划"""
|
||||
|
||||
task_lower = task.lower()
|
||||
|
||||
# 根据任务关键词判断
|
||||
if any(keyword in task_lower for keyword in ["搜索", "查找", "调研", "研究", "research", "search"]):
|
||||
assigned_agent = AgentType.RESEARCH
|
||||
elif any(keyword in task_lower for keyword in ["代码", "写", "开发", "code", "program", "写代码"]):
|
||||
assigned_agent = AgentType.CODER
|
||||
else:
|
||||
assigned_agent = AgentType.CODER
|
||||
|
||||
# 创建默认任务
|
||||
task_item = TaskItem(
|
||||
id="task_1",
|
||||
description=task,
|
||||
assigned_agent=assigned_agent,
|
||||
status="pending"
|
||||
)
|
||||
|
||||
return SupervisorDecision(
|
||||
analysis="简单任务,直接分配给合适的 Agent 执行",
|
||||
task_plan=[task_item],
|
||||
need_aggregation=True,
|
||||
next_worker=assigned_agent
|
||||
)
|
||||
|
||||
|
||||
class ResultAggregator:
|
||||
"""结果聚合器 - 汇总多个任务的结果"""
|
||||
|
||||
def __init__(self, llm: BaseChatModel):
|
||||
self.llm = llm
|
||||
|
||||
def create_node(self):
|
||||
"""创建 LangGraph 节点"""
|
||||
return self._aggregate_node
|
||||
|
||||
async def _aggregate_node(self, state: AgentState) -> dict:
|
||||
"""聚合节点逻辑"""
|
||||
|
||||
# 准备任务计划和结果
|
||||
task_plan = state.get("task_plan", [])
|
||||
results = state.get("results", {})
|
||||
original_task = state.get("original_task", "")
|
||||
|
||||
# 构建任务描述
|
||||
task_descriptions = []
|
||||
for task in task_plan:
|
||||
task_descriptions.append(f"- {task.id}: {task.description} -> {task.status}")
|
||||
|
||||
# 构建结果描述
|
||||
result_items = []
|
||||
for task_id, result in results.items():
|
||||
if isinstance(result, dict):
|
||||
content = result.get("content", str(result))
|
||||
else:
|
||||
content = str(result)
|
||||
result_items.append(f"## {task_id}\n{content}")
|
||||
|
||||
# 调用 LLM 汇总结果
|
||||
from .prompts import AGGREGATOR_SYSTEM_PROMPT
|
||||
|
||||
context_str = json.dumps(state.get("shared_context", {}), ensure_ascii=False, indent=2)
|
||||
|
||||
prompt = AGGREGATOR_SYSTEM_PROMPT.format(
|
||||
original_task=original_task,
|
||||
task_plan="\n".join(task_descriptions),
|
||||
results="\n\n".join(result_items) if result_items else "无结果",
|
||||
context=context_str
|
||||
)
|
||||
|
||||
response = await self.llm.ainvoke([
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(content="请汇总以上结果,给出最终输出。")
|
||||
])
|
||||
|
||||
# 检查是否有失败的任务
|
||||
has_failed = any(task.status == "failed" for task in task_plan)
|
||||
|
||||
return {
|
||||
"final_output": response.content,
|
||||
"status": "failed" if has_failed else "completed",
|
||||
"next_node": "__end__"
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
"""
|
||||
多智能体系统数据类型定义
|
||||
"""
|
||||
from typing import TypedDict, Annotated, Optional, Literal
|
||||
from operator import add
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""任务状态"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
NEEDS_RETRY = "needs_retry"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
"""Agent 类型"""
|
||||
SUPERVISOR = "supervisor"
|
||||
RESEARCH = "research"
|
||||
CODER = "coder"
|
||||
REVIEW = "review"
|
||||
AGGREGATOR = "aggregator"
|
||||
|
||||
|
||||
class TaskItem(BaseModel):
|
||||
"""单个任务项"""
|
||||
id: str = Field(..., description="任务唯一标识")
|
||||
description: str = Field(..., description="任务描述")
|
||||
assigned_agent: AgentType = Field(..., description="分配的 Agent 类型")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="任务状态")
|
||||
result: Optional[dict] = Field(default=None, description="任务执行结果")
|
||||
error: Optional[str] = Field(default=None, description="错误信息")
|
||||
retry_count: int = Field(default=0, description="重试次数")
|
||||
|
||||
|
||||
class SupervisorDecision(BaseModel):
|
||||
"""Supervisor 的结构化决策"""
|
||||
analysis: str = Field(..., description="任务分析")
|
||||
task_plan: list[TaskItem] = Field(..., description="任务计划")
|
||||
need_aggregation: bool = Field(default=True, description="是否需要汇总")
|
||||
next_worker: AgentType = Field(..., description="下一个执行的 Worker")
|
||||
|
||||
|
||||
class ReviewResult(BaseModel):
|
||||
"""Review 结果"""
|
||||
passed: bool = Field(..., description="是否通过")
|
||||
issues: list[dict] = Field(default_factory=list, description="问题列表")
|
||||
suggestions: list[str] = Field(default_factory=list, description="改进建议")
|
||||
retryable: bool = Field(default=True, description="是否可重试")
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
"""贯穿整个图的 Agent 状态"""
|
||||
# 用户输入
|
||||
original_task: str # 原始任务描述
|
||||
session_id: Optional[str] # 会话 ID
|
||||
|
||||
# 任务规划
|
||||
task_plan: list[TaskItem] # 分解后的任务列表
|
||||
current_task_index: int # 当前执行的任务索引
|
||||
|
||||
# 执行结果
|
||||
results: dict # {task_id: result}
|
||||
|
||||
# 流程控制
|
||||
iteration: int # 当前迭代次数
|
||||
next_node: str # 下一个节点名称
|
||||
|
||||
# 共享上下文
|
||||
shared_context: dict # Agent 间共享的数据
|
||||
|
||||
# 最终输出
|
||||
final_output: str
|
||||
status: Literal["running", "completed", "failed"] # 运行状态
|
||||
|
||||
|
||||
def create_initial_state(task: str, session_id: str = None) -> AgentState:
|
||||
"""创建初始状态"""
|
||||
return {
|
||||
"original_task": task,
|
||||
"session_id": session_id,
|
||||
"task_plan": [],
|
||||
"current_task_index": 0,
|
||||
"results": {},
|
||||
"iteration": 0,
|
||||
"next_node": "supervisor",
|
||||
"shared_context": {},
|
||||
"final_output": "",
|
||||
"status": "running"
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
"""
|
||||
Worker Agents
|
||||
"""
|
||||
from .base import BaseWorker
|
||||
from .research import ResearchWorker
|
||||
from .coder import CoderWorker
|
||||
from .review import ReviewWorker
|
||||
|
||||
__all__ = [
|
||||
"BaseWorker",
|
||||
"ResearchWorker",
|
||||
"CoderWorker",
|
||||
"ReviewWorker",
|
||||
]
|
||||
@@ -1,138 +0,0 @@
|
||||
"""
|
||||
Worker Agent 基类
|
||||
"""
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from ..types import AgentState, TaskItem, TaskStatus
|
||||
|
||||
|
||||
class BaseWorker(ABC):
|
||||
"""Worker Agent 基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
name: str,
|
||||
system_prompt: str,
|
||||
tools: list = None,
|
||||
tool_registry=None
|
||||
):
|
||||
self.llm = llm
|
||||
self.name = name
|
||||
self.system_prompt = system_prompt
|
||||
self.tools = tools or []
|
||||
self.tool_registry = tool_registry
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, task: TaskItem, context: dict) -> dict:
|
||||
"""
|
||||
执行任务
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"success": bool,
|
||||
"content": str,
|
||||
"context": dict, # 更新共享上下文
|
||||
"error": str (optional)
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_node(self):
|
||||
"""创建 LangGraph 节点"""
|
||||
async def node(state: AgentState) -> dict:
|
||||
task_index = state.get("current_task_index", 0)
|
||||
task_plan = state.get("task_plan", [])
|
||||
|
||||
if task_index >= len(task_plan):
|
||||
return {"next_node": "aggregator"}
|
||||
|
||||
task = task_plan[task_index]
|
||||
shared_context = state.get("shared_context", {})
|
||||
|
||||
# 更新任务状态为 running
|
||||
updated_plan = self._update_task_status(task_plan, task.id, TaskStatus.RUNNING)
|
||||
|
||||
try:
|
||||
# 执行任务
|
||||
result = await self.execute(task, shared_context)
|
||||
|
||||
# 更新任务状态
|
||||
if result.get("success"):
|
||||
updated_plan = self._update_task_status(
|
||||
updated_plan,
|
||||
task.id,
|
||||
TaskStatus.COMPLETED,
|
||||
result=result.get("content", "")
|
||||
)
|
||||
else:
|
||||
updated_plan = self._update_task_status(
|
||||
updated_plan,
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=result.get("error", "Unknown error")
|
||||
)
|
||||
|
||||
# 构建新上下文
|
||||
new_context = {**shared_context, **(result.get("context", {}))}
|
||||
|
||||
return {
|
||||
"task_plan": updated_plan,
|
||||
"results": {**state.get("results", {}), task.id: result},
|
||||
"shared_context": new_context,
|
||||
"next_node": "review"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 执行出错
|
||||
updated_plan = self._update_task_status(
|
||||
updated_plan,
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
return {
|
||||
"task_plan": updated_plan,
|
||||
"results": {**state.get("results", {}), task.id: {"error": str(e)}},
|
||||
"next_node": "review"
|
||||
}
|
||||
|
||||
return node
|
||||
|
||||
def _update_task_status(
|
||||
self,
|
||||
tasks: list,
|
||||
task_id: str,
|
||||
status: TaskStatus,
|
||||
result: Any = None,
|
||||
error: str = None
|
||||
) -> list:
|
||||
"""更新任务状态"""
|
||||
return [
|
||||
{
|
||||
**task.model_dump() if hasattr(task, 'model_dump') else task,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error
|
||||
}
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
def _build_messages(self, task: str, context: dict) -> list:
|
||||
"""构建消息列表"""
|
||||
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else "无"
|
||||
|
||||
user_prompt = self.system_prompt.format(
|
||||
task=task,
|
||||
context=context_str
|
||||
)
|
||||
|
||||
return [
|
||||
SystemMessage(content=user_prompt),
|
||||
HumanMessage(content=task)
|
||||
]
|
||||
@@ -1,146 +0,0 @@
|
||||
"""
|
||||
Coder Worker - 代码编写和修改
|
||||
"""
|
||||
import json
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from .base import BaseWorker
|
||||
from ..types import TaskItem
|
||||
from ..prompts import CODER_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class CoderWorker(BaseWorker):
|
||||
"""Coder Worker - 代码编写和修改"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
tool_registry=None,
|
||||
tools: list = None
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
name="coder",
|
||||
system_prompt=CODER_SYSTEM_PROMPT,
|
||||
tools=tools or [],
|
||||
tool_registry=tool_registry
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskItem, context: dict) -> dict:
|
||||
"""执行编码任务"""
|
||||
|
||||
# 构建消息
|
||||
messages = self._build_messages(task.description, context)
|
||||
|
||||
# 如果有代码执行工具,启用它
|
||||
if self.tool_registry:
|
||||
tool_defs = self._get_available_tools()
|
||||
if tool_defs:
|
||||
try:
|
||||
response = await self.llm.agenerate(
|
||||
messages=messages,
|
||||
tools=tool_defs
|
||||
)
|
||||
return self._handle_tool_response(response, messages)
|
||||
except Exception:
|
||||
# 如果工具调用失败,回退到普通调用
|
||||
pass
|
||||
|
||||
# 普通调用
|
||||
try:
|
||||
response = await self.llm.ainvoke(messages)
|
||||
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"context": {
|
||||
"code_written": True,
|
||||
"last_coder": self.name
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e),
|
||||
"context": {}
|
||||
}
|
||||
|
||||
def _get_available_tools(self) -> list:
|
||||
"""获取可用工具定义"""
|
||||
if not self.tool_registry:
|
||||
return []
|
||||
|
||||
tool_names = self.tools or ["search", "execute_code"]
|
||||
tool_defs = []
|
||||
|
||||
for tool_name in tool_names:
|
||||
tool_def = self.tool_registry.get_tool_definition(tool_name)
|
||||
if tool_def:
|
||||
tool_defs.append(tool_def)
|
||||
|
||||
return tool_defs
|
||||
|
||||
def _handle_tool_response(self, response, original_messages: list) -> dict:
|
||||
"""处理工具调用响应"""
|
||||
# 简化实现
|
||||
response_message = response.generations[0][0]
|
||||
|
||||
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
|
||||
# 有工具调用
|
||||
tool_results = []
|
||||
for tool_call in response_message.tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_args = tool_call.arguments
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
tool_func, _ = self.tool_registry.get_tool(tool_name)
|
||||
result = tool_func(**tool_args)
|
||||
tool_results.append({
|
||||
"tool": tool_name,
|
||||
"result": str(result)
|
||||
})
|
||||
except Exception as e:
|
||||
tool_results.append({
|
||||
"tool": tool_name,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# 将工具结果添加到消息
|
||||
for msg in response.generations[0]:
|
||||
original_messages.append(msg)
|
||||
|
||||
for tool_result in tool_results:
|
||||
original_messages.append({
|
||||
"role": "tool",
|
||||
"content": json.dumps(tool_result, ensure_ascii=False)
|
||||
})
|
||||
|
||||
# 再次调用 LLM 生成最终响应
|
||||
final_response = await self.llm.ainvoke(original_messages)
|
||||
content = final_response.content if hasattr(final_response, 'content') else str(final_response)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"context": {
|
||||
"code_written": True,
|
||||
"tool_results": tool_results,
|
||||
"last_coder": self.name
|
||||
}
|
||||
}
|
||||
else:
|
||||
# 无工具调用
|
||||
content = response_message.text if hasattr(response_message, 'text') else str(response_message)
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"context": {
|
||||
"code_written": True,
|
||||
"last_coder": self.name
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
Research Worker - 信息搜索和调研
|
||||
"""
|
||||
import json
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from .base import BaseWorker
|
||||
from ..types import TaskItem
|
||||
from ..prompts import RESEARCH_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class ResearchWorker(BaseWorker):
|
||||
"""Research Worker - 信息搜索和调研"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
tool_registry=None,
|
||||
tools: list = None
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
name="research",
|
||||
system_prompt=RESEARCH_SYSTEM_PROMPT,
|
||||
tools=tools or [],
|
||||
tool_registry=tool_registry
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskItem, context: dict) -> dict:
|
||||
"""执行调研任务"""
|
||||
|
||||
# 构建消息
|
||||
messages = self._build_messages(task.description, context)
|
||||
|
||||
try:
|
||||
# 调用 LLM
|
||||
response = await self.llm.ainvoke(messages)
|
||||
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 尝试提取搜索结果
|
||||
search_results = self._extract_search_results(content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"context": {
|
||||
"research_results": search_results,
|
||||
"last_research_by": self.name
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e),
|
||||
"context": {}
|
||||
}
|
||||
|
||||
def _extract_search_results(self, content: str) -> list:
|
||||
"""从内容中提取搜索结果"""
|
||||
# 简单实现:查找以 - 或 * 开头的行
|
||||
results = []
|
||||
for line in content.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith(('- ', '* ', '1. ', '2. ', '3. ')):
|
||||
results.append(line.lstrip('-*123. '))
|
||||
|
||||
return results[:10] # 限制数量
|
||||
@@ -1,174 +0,0 @@
|
||||
"""
|
||||
Review Worker - 结果检查和质量评审
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from .base import BaseWorker
|
||||
from ..types import AgentState, TaskItem, TaskStatus, ReviewResult
|
||||
from ..prompts import REVIEW_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class ReviewWorker(BaseWorker):
|
||||
"""Review Worker - 结果检查和质量评审"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
tool_registry=None,
|
||||
tools: list = None
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
name="review",
|
||||
system_prompt=REVIEW_SYSTEM_PROMPT,
|
||||
tools=tools or [],
|
||||
tool_registry=tool_registry
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskItem, context: dict) -> dict:
|
||||
"""执行评审任务"""
|
||||
|
||||
# 获取当前任务索引和任务计划
|
||||
# 注意:这里需要从 context 中获取更多信息
|
||||
|
||||
# 构建 prompt
|
||||
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else "无"
|
||||
|
||||
prompt = REVIEW_SYSTEM_PROMPT.format(
|
||||
original_task=context.get("original_task", ""),
|
||||
task_description=task.description,
|
||||
execution_result=task.result if task.result else "无结果",
|
||||
context=context_str
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用 LLM 进行评审
|
||||
response = await self.llm.ainvoke([
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(content="请评审以上执行结果。")
|
||||
])
|
||||
|
||||
# 解析评审结果
|
||||
review_result = self._parse_review_response(response.content)
|
||||
|
||||
# 根据评审结果决定下一步
|
||||
if review_result.passed:
|
||||
# 通过,更新任务状态为 completed
|
||||
new_status = TaskStatus.COMPLETED
|
||||
next_node = "supervisor" # 返回 Supervisor 继续执行
|
||||
else:
|
||||
# 未通过,检查是否可重试
|
||||
if review_result.retryable:
|
||||
new_status = TaskStatus.NEEDS_RETRY
|
||||
next_node = "supervisor" # 返回 Supervisor 决定是否重试
|
||||
else:
|
||||
new_status = TaskStatus.FAILED
|
||||
next_node = "aggregator" # 失败,进入汇总
|
||||
|
||||
return {
|
||||
"success": review_result.passed,
|
||||
"content": response.content,
|
||||
"review_result": review_result.model_dump() if hasattr(review_result, 'model_dump') else dict(review_result),
|
||||
"context": {
|
||||
"review_passed": review_result.passed,
|
||||
"issues": review_result.issues,
|
||||
"last_review_by": self.name
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e),
|
||||
"context": {}
|
||||
}
|
||||
|
||||
def _parse_review_response(self, response: str) -> ReviewResult:
|
||||
"""解析评审响应"""
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("No JSON found")
|
||||
|
||||
return ReviewResult(
|
||||
passed=data.get("passed", True),
|
||||
issues=data.get("issues", []),
|
||||
suggestions=data.get("suggestions", []),
|
||||
retryable=data.get("retryable", True)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# 解析失败,默认通过
|
||||
return ReviewResult(
|
||||
passed=True,
|
||||
issues=[],
|
||||
suggestions=[],
|
||||
retryable=True
|
||||
)
|
||||
|
||||
def create_node(self):
|
||||
"""创建 LangGraph 节点"""
|
||||
async def node(state: AgentState) -> dict:
|
||||
task_index = state.get("current_task_index", 0)
|
||||
task_plan = state.get("task_plan", [])
|
||||
|
||||
if task_index >= len(task_plan):
|
||||
return {"next_node": "aggregator"}
|
||||
|
||||
task = task_plan[task_index]
|
||||
shared_context = {
|
||||
**state.get("shared_context", {}),
|
||||
"original_task": state.get("original_task", "")
|
||||
}
|
||||
|
||||
try:
|
||||
# 执行评审
|
||||
result = await self.execute(task, shared_context)
|
||||
|
||||
# 更新任务状态
|
||||
review_passed = result.get("review_result", {}).get("passed", True)
|
||||
retryable = result.get("review_result", {}).get("retryable", True)
|
||||
|
||||
if review_passed:
|
||||
updated_status = TaskStatus.COMPLETED
|
||||
elif retryable:
|
||||
updated_status = TaskStatus.NEEDS_RETRY
|
||||
else:
|
||||
updated_status = TaskStatus.FAILED
|
||||
|
||||
updated_plan = self._update_task_status(
|
||||
task_plan,
|
||||
task.id,
|
||||
updated_status,
|
||||
result=task.result
|
||||
)
|
||||
|
||||
# 确定下一步
|
||||
if updated_status == TaskStatus.COMPLETED:
|
||||
next_node = "supervisor"
|
||||
elif updated_status == TaskStatus.NEEDS_RETRY:
|
||||
next_node = "supervisor"
|
||||
else:
|
||||
next_node = "aggregator"
|
||||
|
||||
return {
|
||||
"task_plan": updated_plan,
|
||||
"results": {**state.get("results", {}), f"{task.id}_review": result},
|
||||
"shared_context": {**shared_context, **result.get("context", {})},
|
||||
"next_node": next_node
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"next_node": "aggregator",
|
||||
"results": {**state.get("results", {}), f"{task.id}_review": {"error": str(e)}}
|
||||
}
|
||||
|
||||
return node
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
工具实现模块
|
||||
"""
|
||||
|
||||
# 基础工具
|
||||
from . import search
|
||||
from . import calculator
|
||||
from . import time_tool
|
||||
|
||||
# 安全工具
|
||||
from . import sandbox
|
||||
from . import database
|
||||
from . import api_client
|
||||
|
||||
__all__ = [
|
||||
"search",
|
||||
"calculator",
|
||||
"time_tool",
|
||||
"sandbox",
|
||||
"database",
|
||||
"api_client",
|
||||
]
|
||||
@@ -1,166 +0,0 @@
|
||||
"""
|
||||
API 调用工具 - 安全的外部 API 调用
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class APIPermission(Enum):
|
||||
"""API 权限级别"""
|
||||
PUBLIC = "public" # 公开 API
|
||||
APPROVED = "approved" # 已审批的 API
|
||||
ADMIN = "admin" # 管理员 API
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIEndpoint:
|
||||
"""API 端点定义"""
|
||||
name: str
|
||||
url: str
|
||||
method: str
|
||||
permission: APIPermission
|
||||
description: str
|
||||
rate_limit: int = 60 # 每分钟请求次数
|
||||
|
||||
|
||||
# API 白名单
|
||||
ALLOWED_APIS = [
|
||||
APIEndpoint(
|
||||
name="weather",
|
||||
url="https://api.weather.example.com/v1",
|
||||
method="GET",
|
||||
permission=APIPermission.PUBLIC,
|
||||
description="获取天气信息",
|
||||
rate_limit=30
|
||||
),
|
||||
APIEndpoint(
|
||||
name="news",
|
||||
url="https://newsapi.org/v2",
|
||||
method="GET",
|
||||
permission=APIPermission.PUBLIC,
|
||||
description="获取新闻",
|
||||
rate_limit=30
|
||||
),
|
||||
# 可以添加更多已审批的 API
|
||||
]
|
||||
|
||||
|
||||
class APICallTool:
|
||||
"""
|
||||
API 调用工具
|
||||
|
||||
安全特性:
|
||||
- 只允许调用白名单中的 API
|
||||
- 速率限制
|
||||
- 请求超时
|
||||
- 响应大小限制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.allowed_apis = {api.name: api for api in ALLOWED_APIS}
|
||||
self.request_timeout = 10 # 请求超时(秒)
|
||||
self.max_response_size = 1024 * 1024 # 最大响应大小(1MB)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
api_name: str,
|
||||
endpoint: str = "",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
调用 API
|
||||
|
||||
Args:
|
||||
api_name: API 名称(必须在白名单中)
|
||||
endpoint: 具体的端点
|
||||
params: 查询参数
|
||||
headers: 请求头
|
||||
|
||||
Returns:
|
||||
API 响应
|
||||
"""
|
||||
# 安全检查1: API 必须在白名单中
|
||||
if api_name not in self.allowed_apis:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"API '{api_name}' not in whitelist. Allowed: {list(self.allowed_apis.keys())}"
|
||||
}
|
||||
|
||||
api = self.allowed_apis[api_name]
|
||||
|
||||
# 构建完整 URL
|
||||
url = f"{api.url}/{endpoint}" if endpoint else api.url
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.request_timeout) as client:
|
||||
# 根据方法调用
|
||||
if api.method == "GET":
|
||||
response = await client.get(url, params=params, headers=headers)
|
||||
elif api.method == "POST":
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Method {api.method} not supported"
|
||||
}
|
||||
|
||||
# 检查响应大小
|
||||
if len(response.content) > self.max_response_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Response too large (max {self.max_response_size} bytes)"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"status_code": response.status_code,
|
||||
"data": response.json() if response.headers.get("content-type", "").startswith("application/json") else response.text,
|
||||
"headers": dict(response.headers)
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Request timeout"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def list_apis(self) -> list:
|
||||
"""列出所有可用的 API"""
|
||||
return [
|
||||
{
|
||||
"name": api.name,
|
||||
"description": api.description,
|
||||
"method": api.method,
|
||||
"permission": api.permission.value,
|
||||
"rate_limit": api.rate_limit
|
||||
}
|
||||
for api in ALLOWED_APIS
|
||||
]
|
||||
|
||||
|
||||
# 全局实例
|
||||
api_tool = APICallTool()
|
||||
|
||||
|
||||
async def call_api(
|
||||
api_name: str,
|
||||
endpoint: str = "",
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
API 调用工具(供 Agent 使用)
|
||||
"""
|
||||
return await api_tool.call(api_name, endpoint, params)
|
||||
|
||||
|
||||
def list_allowed_apis() -> list:
|
||||
"""列出允许的 API"""
|
||||
return api_tool.list_apis()
|
||||
@@ -1,91 +0,0 @@
|
||||
"""
|
||||
计算器工具
|
||||
"""
|
||||
import ast
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
|
||||
# 安全运算符
|
||||
SAFE_OPERATORS = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.Pow: operator.pow,
|
||||
ast.Mod: operator.mod,
|
||||
ast.USub: operator.neg,
|
||||
}
|
||||
|
||||
|
||||
def safe_eval_expr(node):
|
||||
"""安全地求值表达式节点"""
|
||||
if isinstance(node, ast.Num):
|
||||
return node.n
|
||||
elif isinstance(node, ast.BinOp):
|
||||
left = safe_eval_expr(node.left)
|
||||
right = safe_eval_expr(node.right)
|
||||
op_type = type(node.op)
|
||||
if op_type in SAFE_OPERATORS:
|
||||
return SAFE_OPERATORS[op_type](left, right)
|
||||
raise ValueError(f"Unsupported operator: {op_type}")
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
operand = safe_eval_expr(node.operand)
|
||||
op_type = type(node.op)
|
||||
if op_type in SAFE_OPERATORS:
|
||||
return SAFE_OPERATORS[op_type](operand)
|
||||
raise ValueError(f"Unsupported unary operator: {op_type}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported expression: {ast.dump(node)}")
|
||||
|
||||
|
||||
def calculate(expression: str) -> dict:
|
||||
"""
|
||||
执行数学计算
|
||||
|
||||
Args:
|
||||
expression: 数学表达式,如 "2 + 2" 或 "sqrt(16)"
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
"""
|
||||
try:
|
||||
# 预处理:处理常见数学函数
|
||||
expression = expression.replace("sqrt", "**0.5")
|
||||
expression = expression.replace("pi", "3.14159265359")
|
||||
expression = expression.replace("e", "2.71828182846")
|
||||
|
||||
# 解析表达式
|
||||
tree = ast.parse(expression, mode='eval')
|
||||
result = safe_eval_expr(tree.body)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"expression": expression,
|
||||
"result": result,
|
||||
"type": type(result).__name__
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"expression": expression,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
TOOL_DEFINITION = {
|
||||
"name": "calculator",
|
||||
"description": "Perform mathematical calculations. Supports basic arithmetic (+, -, *, /), powers (**), and functions (sqrt).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Mathematical expression to evaluate, e.g., '2 + 2' or 'sqrt(16) + 5'"
|
||||
}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
"""
|
||||
数据库查询工具 - 安全的数据查询接口
|
||||
"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
import os
|
||||
|
||||
|
||||
# 只读查询白名单 - 只允许 SELECT 语句
|
||||
ALLOWED_TABLES = ["users", "agents", "sessions", "audit_logs"]
|
||||
|
||||
|
||||
class DatabaseQueryTool:
|
||||
"""
|
||||
数据库查询工具
|
||||
|
||||
安全特性:
|
||||
- 只允许 SELECT 查询
|
||||
- 表名白名单
|
||||
- 结果数量限制
|
||||
"""
|
||||
|
||||
def __init__(self, connection_string: str = ""):
|
||||
self.connection_string = connection_string or os.getenv(
|
||||
"DATABASE_URL",
|
||||
"postgresql://postgres:postgres@localhost:5432/x_agents"
|
||||
)
|
||||
self.max_rows = 100 # 最多返回100行
|
||||
|
||||
def query(self, sql: str, params: List[Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
执行查询
|
||||
|
||||
Args:
|
||||
sql: SQL 查询语句(必须是 SELECT)
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
# 安全检查1: 必须是 SELECT 语句
|
||||
sql_upper = sql.strip().upper()
|
||||
if not sql_upper.startswith("SELECT"):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Only SELECT queries are allowed"
|
||||
}
|
||||
|
||||
# 安全检查2: 禁止危险关键字
|
||||
dangerous_keywords = [
|
||||
"DROP", "DELETE", "INSERT", "UPDATE", "ALTER",
|
||||
"CREATE", "TRUNCATE", "EXEC", "EXECUTE"
|
||||
]
|
||||
for keyword in dangerous_keywords:
|
||||
if keyword in sql_upper:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Keyword '{keyword}' is not allowed"
|
||||
}
|
||||
|
||||
# 安全检查3: 表名白名单
|
||||
for table in ALLOWED_TABLES:
|
||||
if f"FROM {table}" in sql_upper or f"JOIN {table}" in sql_upper:
|
||||
# 表名在白名单中,允许
|
||||
break
|
||||
else:
|
||||
# 没有找到白名单表
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Table not in whitelist. Allowed: {ALLOWED_TABLES}"
|
||||
}
|
||||
|
||||
# TODO: 实际执行查询(需要数据库连接)
|
||||
# 这里返回模拟数据
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Query executed (mock mode - database not connected)",
|
||||
"rows": [],
|
||||
"columns": []
|
||||
}
|
||||
|
||||
|
||||
# 全局实例
|
||||
db_tool = DatabaseQueryTool()
|
||||
|
||||
|
||||
def query_data(sql: str) -> Dict[str, Any]:
|
||||
"""
|
||||
查询数据工具
|
||||
|
||||
Args:
|
||||
sql: SELECT 查询语句
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
return db_tool.query(sql)
|
||||
@@ -1,87 +0,0 @@
|
||||
"""
|
||||
网页搜索工具
|
||||
"""
|
||||
import httpx
|
||||
from typing import Optional
|
||||
|
||||
|
||||
async def search_web(query: str, max_results: int = 5) -> dict:
|
||||
"""
|
||||
搜索网页获取信息
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 返回结果数量
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
# 这里可以使用搜索引擎API,如 Google, Bing, DuckDuckGo 等
|
||||
# 示例使用 DuckDuckGo API(免费)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
"https://api.duckduckgo.com/",
|
||||
params={
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": 1,
|
||||
"skip_disambig": 1
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
results = []
|
||||
|
||||
# 提取相关主题
|
||||
if "RelatedTopics" in data:
|
||||
for item in data["RelatedTopics"][:max_results]:
|
||||
if "Text" in item:
|
||||
results.append({
|
||||
"title": item.get("Text", "").split(" - ")[0] if " - " in item.get("Text", "") else "",
|
||||
"content": item.get("Text", ""),
|
||||
"url": item.get("URL", "")
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Search API returned status {response.status_code}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义(用于 LLM)
|
||||
TOOL_DEFINITION = {
|
||||
"name": "search",
|
||||
"description": "Search the web for information. Use this when you need to find current information or facts.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
时间工具
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_current_time(timezone: Optional[str] = None) -> dict:
|
||||
"""
|
||||
获取当前时间
|
||||
|
||||
Args:
|
||||
timezone: 时区名称,如 "UTC", "Asia/Shanghai"
|
||||
|
||||
Returns:
|
||||
当前时间信息
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"datetime": now.isoformat(),
|
||||
"timestamp": now.timestamp(),
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"time": now.strftime("%H:%M:%S"),
|
||||
"weekday": now.strftime("%A"),
|
||||
"timezone": timezone or "Local Time"
|
||||
}
|
||||
|
||||
|
||||
def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> dict:
|
||||
"""
|
||||
格式化时间戳
|
||||
|
||||
Args:
|
||||
timestamp: Unix 时间戳
|
||||
format_str: 格式字符串
|
||||
|
||||
Returns:
|
||||
格式化后的时间
|
||||
"""
|
||||
try:
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
return {
|
||||
"success": True,
|
||||
"formatted": dt.strftime(format_str),
|
||||
"datetime": dt.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
TOOL_DEFINITION = {
|
||||
"name": "get_current_time",
|
||||
"description": "Get the current date and time. Useful for timestamps or scheduling.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
|
||||
"default": "Local"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
"""
|
||||
工具注册表 - 管理所有可用工具(白名单机制)
|
||||
"""
|
||||
from typing import Any, Callable, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""工具安全等级"""
|
||||
SAFE = "safe" # 安全操作
|
||||
REVIEW = "review" # 需要审核
|
||||
DANGER = "danger" # 危险操作
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""工具元数据"""
|
||||
name: str
|
||||
description: str
|
||||
security_level: str
|
||||
require_approval: bool = False
|
||||
allowed_roles: list = None
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"security_level": self.security_level,
|
||||
"require_approval": self.require_approval
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, tuple[Callable, ToolMetadata]] = {}
|
||||
self._definitions: dict[str, dict] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
func: Callable,
|
||||
description: str = "",
|
||||
security_level: str = "safe",
|
||||
require_approval: bool = False,
|
||||
allowed_roles: list = None,
|
||||
parameters: dict = None
|
||||
):
|
||||
"""注册工具到白名单"""
|
||||
metadata = ToolMetadata(
|
||||
name=name,
|
||||
description=description,
|
||||
security_level=security_level,
|
||||
require_approval=require_approval,
|
||||
allowed_roles=allowed_roles or ["user", "admin"]
|
||||
)
|
||||
|
||||
self._tools[name] = (func, metadata)
|
||||
|
||||
# 生成工具定义(用于 LLM 调用)
|
||||
self._definitions[name] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
|
||||
"""获取工具函数和元数据"""
|
||||
if name not in self._tools:
|
||||
raise ValueError(f"Tool '{name}' not found in whitelist")
|
||||
return self._tools[name]
|
||||
|
||||
def get_tool_definition(self, name: str) -> Optional[dict]:
|
||||
"""获取工具定义(用于 LLM)"""
|
||||
return self._definitions.get(name)
|
||||
|
||||
def list_tools(self) -> list[ToolMetadata]:
|
||||
"""列出所有已注册工具"""
|
||||
return [meta for _, meta in self._tools.values()]
|
||||
|
||||
def check_permission(self, tool_name: str, user_role: str) -> bool:
|
||||
"""检查用户权限"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return user_role in metadata.allowed_roles
|
||||
|
||||
def need_approval(self, tool_name: str) -> bool:
|
||||
"""判断是否需要审批"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return metadata.require_approval
|
||||
@@ -1,283 +0,0 @@
|
||||
"""
|
||||
沙盒执行环境 - 在项目内构建,不依赖 Docker
|
||||
提供安全的代码执行环境
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import resource
|
||||
import signal
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SandboxConfig:
|
||||
"""沙盒配置"""
|
||||
# 资源限制
|
||||
MAX_MEMORY_MB = 256 # 最大内存 (MB)
|
||||
MAX_CPU_PERCENT = 50 # 最大 CPU 百分比
|
||||
MAX_EXECUTION_TIME = 30 # 最大执行时间 (秒)
|
||||
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小 (bytes)
|
||||
|
||||
|
||||
class Sandbox:
|
||||
"""
|
||||
沙盒执行器 - 使用 subprocess 隔离执行
|
||||
|
||||
安全特性:
|
||||
- 内存限制
|
||||
- CPU限制
|
||||
- 超时控制
|
||||
- 网络隔离(可选)
|
||||
- 临时文件隔离
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SandboxConfig] = None):
|
||||
self.config = config or SandboxConfig()
|
||||
self.temp_dir = None
|
||||
|
||||
def _setup_temp_dir(self) -> str:
|
||||
"""创建临时目录"""
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="sandbox_")
|
||||
return self.temp_dir
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理临时目录"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
try:
|
||||
shutil.rmtree(self.temp_dir)
|
||||
except Exception as e:
|
||||
print(f"Cleanup error: {e}")
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
在沙盒中执行代码
|
||||
|
||||
Args:
|
||||
code: 要执行的代码
|
||||
language: 语言类型 (python, javascript)
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
if language == "python":
|
||||
return self._execute_python(code, timeout)
|
||||
elif language == "javascript":
|
||||
return self._execute_javascript(code, timeout)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unsupported language: {language}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _execute_python(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 Python 代码"""
|
||||
# 创建临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.py")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 构建命令
|
||||
cmd = ["python", temp_file]
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir, # 限制工作目录
|
||||
env=self._get_restricted_env(), # 限制环境变量
|
||||
)
|
||||
|
||||
# 检查输出大小
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
|
||||
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _execute_javascript(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 JavaScript 代码"""
|
||||
temp_file = os.path.join(self.temp_dir, "code.js")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 尝试使用 node
|
||||
cmd = ["node", temp_file]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
)
|
||||
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Node.js not installed",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _get_restricted_env(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取受限的环境变量
|
||||
移除敏感变量,保留必要的 PATH
|
||||
"""
|
||||
# 保留 PATH,移除其他敏感变量
|
||||
safe_env = {
|
||||
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
|
||||
"LANG": "en_US.UTF-8",
|
||||
"HOME": self.temp_dir,
|
||||
"TMPDIR": self.temp_dir,
|
||||
}
|
||||
|
||||
# 移除可能不安全的变量
|
||||
unsafe_vars = [
|
||||
"PYTHONPATH",
|
||||
"PYTHONHOME",
|
||||
"LD_PRELOAD",
|
||||
"LD_LIBRARY_PATH",
|
||||
]
|
||||
|
||||
for var in unsafe_vars:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
return safe_env
|
||||
|
||||
|
||||
class SafeEval:
|
||||
"""
|
||||
安全求值器 - 用于简单表达式计算
|
||||
比沙盒更轻量,适用于不需要完全隔离的场景
|
||||
"""
|
||||
|
||||
# 安全函数白名单
|
||||
SAFE_BUILTINS = {
|
||||
"abs": abs,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"len": len,
|
||||
"round": round,
|
||||
"pow": pow,
|
||||
"print": print,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"set": set,
|
||||
"range": range,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"sorted": sorted,
|
||||
"reversed": reversed,
|
||||
}
|
||||
|
||||
# 安全数学常量
|
||||
SAFE_MATH = {
|
||||
"pi": 3.14159265359,
|
||||
"e": 2.71828182846,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def eval(cls, expression: str) -> Any:
|
||||
"""
|
||||
安全地求值表达式
|
||||
|
||||
Args:
|
||||
expression: 数学表达式
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
"""
|
||||
# 预处理表达式
|
||||
expression = expression.replace("sqrt", "**0.5")
|
||||
|
||||
# 构建安全命名空间
|
||||
safe_namespace = {
|
||||
**cls.SAFE_BUILTINS,
|
||||
**cls.SAFE_MATH,
|
||||
"__builtins__": {} # 禁用__builtins__
|
||||
}
|
||||
|
||||
try:
|
||||
result = eval(expression, safe_namespace)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise ValueError(f"Evaluation error: {e}")
|
||||
|
||||
|
||||
# 全局沙盒实例
|
||||
sandbox = Sandbox()
|
||||
|
||||
|
||||
# 装饰器:快速将函数封装为沙盒执行
|
||||
def sandboxed(timeout: int = 30):
|
||||
"""装饰器:为函数添加沙盒执行能力"""
|
||||
def decorator(func):
|
||||
def wrapper(code: str, *args, **kwargs):
|
||||
result = sandbox.execute(code, timeout=timeout)
|
||||
if not result["success"]:
|
||||
raise RuntimeError(result.get("error", "Execution failed"))
|
||||
return result["output"]
|
||||
return wrapper
|
||||
return decorator
|
||||
@@ -1,149 +0,0 @@
|
||||
"""
|
||||
API 路由定义
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agent.core.agent import AgentManager
|
||||
from app.security.approval import ApprovalService
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局依赖(实际应该注入)
|
||||
_agent_manager: Optional[AgentManager] = None
|
||||
_approval_service: Optional[ApprovalService] = None
|
||||
|
||||
|
||||
def get_agent_manager() -> AgentManager:
|
||||
"""获取 Agent 管理器"""
|
||||
# 这里应该从 app.state 获取
|
||||
from app.main import agent_manager
|
||||
if agent_manager is None:
|
||||
raise HTTPException(status_code=503, detail="Agent service not initialized")
|
||||
return agent_manager
|
||||
|
||||
|
||||
def get_approval_service() -> ApprovalService:
|
||||
"""获取审批服务"""
|
||||
global _approval_service
|
||||
if _approval_service is None:
|
||||
_approval_service = ApprovalService()
|
||||
return _approval_service
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""聊天请求"""
|
||||
agent_id: str
|
||||
message: str
|
||||
session_id: str = ""
|
||||
context: dict = {}
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""聊天响应"""
|
||||
reply: str
|
||||
session_id: str
|
||||
tools_used: list[str] = []
|
||||
metadata: dict = {}
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""审批请求"""
|
||||
request_id: str
|
||||
tool_name: str
|
||||
params: dict
|
||||
reason: str
|
||||
approved: bool
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat(
|
||||
request: ChatRequest,
|
||||
agent_manager: AgentManager = Depends(get_agent_manager)
|
||||
):
|
||||
"""处理 Agent 聊天请求"""
|
||||
try:
|
||||
# 生成会话ID
|
||||
if not request.session_id:
|
||||
import uuid
|
||||
request.session_id = str(uuid.uuid4())
|
||||
|
||||
# 执行 Agent
|
||||
result = await agent_manager.execute(
|
||||
agent_id=request.agent_id,
|
||||
message=request.message,
|
||||
session_id=request.session_id,
|
||||
context=request.context
|
||||
)
|
||||
|
||||
return ChatResponse(
|
||||
reply=result.get("reply", ""),
|
||||
session_id=request.session_id,
|
||||
tools_used=result.get("tools_used", []),
|
||||
metadata=result.get("metadata", {})
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Agent execution failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tool/request")
|
||||
async def request_tool_execution(
|
||||
request: dict,
|
||||
approval_service: ApprovalService = Depends(get_approval_service)
|
||||
):
|
||||
"""请求执行工具(需要审批)"""
|
||||
tool_name = request.get("tool_name")
|
||||
params = request.get("params", {})
|
||||
user_id = request.get("user_id", "unknown")
|
||||
agent_id = request.get("agent_id")
|
||||
reason = request.get("reason", "")
|
||||
|
||||
# 创建审批请求
|
||||
request_id = await approval_service.request_approval(
|
||||
tool_name=tool_name,
|
||||
params=params,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id or "",
|
||||
reason=reason
|
||||
)
|
||||
|
||||
return {
|
||||
"request_id": request_id,
|
||||
"status": "pending"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/tools")
|
||||
async def list_tools(agent_manager: AgentManager = Depends(get_agent_manager)):
|
||||
"""列出所有可用工具"""
|
||||
tools = agent_manager.list_tools()
|
||||
return {"tools": [tool.dict() for tool in tools]}
|
||||
|
||||
|
||||
@router.get("/agents")
|
||||
async def list_agents(agent_manager: AgentManager = Depends(get_agent_manager)):
|
||||
"""列出所有已加载的 Agent"""
|
||||
agents = agent_manager.list_agents()
|
||||
return {"agents": agents}
|
||||
|
||||
|
||||
@router.get("/agent/{agent_id}")
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
agent_manager: AgentManager = Depends(get_agent_manager)
|
||||
):
|
||||
"""获取特定 Agent 信息"""
|
||||
agent_info = agent_manager.get_agent_info(agent_id)
|
||||
if not agent_info:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
return agent_info
|
||||
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
Core 模块 - AI 核心能力
|
||||
"""
|
||||
from . import tools
|
||||
|
||||
__all__ = [
|
||||
"tools",
|
||||
]
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
Agent 工具模块
|
||||
"""
|
||||
from .registry import ToolRegistry, ToolMetadata, SecurityLevel, global_registry
|
||||
from . import impl
|
||||
|
||||
# 导入所有工具函数和定义
|
||||
from .impl import (
|
||||
# 文件操作
|
||||
read_file,
|
||||
write_file,
|
||||
list_dir,
|
||||
delete_file,
|
||||
search_files,
|
||||
READ_FILE_TOOL,
|
||||
WRITE_FILE_TOOL,
|
||||
LIST_DIR_TOOL,
|
||||
DELETE_FILE_TOOL,
|
||||
SEARCH_FILES_TOOL,
|
||||
|
||||
# 代码执行
|
||||
execute_python,
|
||||
execute_javascript,
|
||||
execute_bash,
|
||||
EXECUTE_PYTHON_TOOL,
|
||||
EXECUTE_JAVASCRIPT_TOOL,
|
||||
EXECUTE_BASH_TOOL,
|
||||
|
||||
# 网页
|
||||
web_fetch,
|
||||
web_search,
|
||||
WEB_FETCH_TOOL,
|
||||
WEB_SEARCH_TOOL,
|
||||
|
||||
# HTTP
|
||||
http_request,
|
||||
http_get,
|
||||
http_post,
|
||||
http_put,
|
||||
http_delete,
|
||||
HTTP_REQUEST_TOOL,
|
||||
|
||||
# 通知
|
||||
send_notification,
|
||||
send_email,
|
||||
send_webhook,
|
||||
SEND_NOTIFICATION_TOOL,
|
||||
|
||||
# 时间
|
||||
get_current_time,
|
||||
format_time,
|
||||
GET_CURRENT_TIME_TOOL,
|
||||
)
|
||||
|
||||
|
||||
def register_all_tools(registry: ToolRegistry = None):
|
||||
"""
|
||||
注册所有工具到注册表
|
||||
|
||||
Args:
|
||||
registry: 工具注册表,默认使用全局注册表
|
||||
"""
|
||||
reg = registry or global_registry
|
||||
|
||||
# 文件操作
|
||||
reg.register("read_file", read_file, READ_FILE_TOOL["description"], "safe", parameters=READ_FILE_TOOL["parameters"])
|
||||
reg.register("write_file", write_file, WRITE_FILE_TOOL["description"], "review", parameters=WRITE_FILE_TOOL["parameters"])
|
||||
reg.register("list_dir", list_dir, LIST_DIR_TOOL["description"], "safe", parameters=LIST_DIR_TOOL["parameters"])
|
||||
reg.register("delete_file", delete_file, DELETE_FILE_TOOL["description"], "danger", parameters=DELETE_FILE_TOOL["parameters"])
|
||||
reg.register("search_files", search_files, SEARCH_FILES_TOOL["description"], "safe", parameters=SEARCH_FILES_TOOL["parameters"])
|
||||
|
||||
# 代码执行
|
||||
reg.register("execute_python", execute_python, EXECUTE_PYTHON_TOOL["description"], "review", parameters=EXECUTE_PYTHON_TOOL["parameters"])
|
||||
reg.register("execute_javascript", execute_javascript, EXECUTE_JAVASCRIPT_TOOL["description"], "review", parameters=EXECUTE_JAVASCRIPT_TOOL["parameters"])
|
||||
reg.register("execute_bash", execute_bash, EXECUTE_BASH_TOOL["description"], "danger", parameters=EXECUTE_BASH_TOOL["parameters"])
|
||||
|
||||
# 网页
|
||||
reg.register("web_fetch", web_fetch, WEB_FETCH_TOOL["description"], "safe", parameters=WEB_FETCH_TOOL["parameters"])
|
||||
reg.register("web_search", web_search, WEB_SEARCH_TOOL["description"], "safe", parameters=WEB_SEARCH_TOOL["parameters"])
|
||||
|
||||
# HTTP
|
||||
reg.register("http_request", http_request, HTTP_REQUEST_TOOL["description"], "safe", parameters=HTTP_REQUEST_TOOL["parameters"])
|
||||
|
||||
# 通知
|
||||
reg.register("send_notification", send_notification, SEND_NOTIFICATION_TOOL["description"], "safe", parameters=SEND_NOTIFICATION_TOOL["parameters"])
|
||||
|
||||
# 时间
|
||||
reg.register("get_current_time", get_current_time, GET_CURRENT_TIME_TOOL["description"], "safe", parameters=GET_CURRENT_TIME_TOOL["parameters"])
|
||||
|
||||
return reg
|
||||
|
||||
|
||||
# 注册所有工具
|
||||
register_all_tools(global_registry)
|
||||
|
||||
__all__ = [
|
||||
"ToolRegistry",
|
||||
"ToolMetadata",
|
||||
"SecurityLevel",
|
||||
"global_registry",
|
||||
"register_all_tools",
|
||||
"impl",
|
||||
|
||||
# 所有工具函数
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_dir",
|
||||
"delete_file",
|
||||
"search_files",
|
||||
|
||||
"execute_python",
|
||||
"execute_javascript",
|
||||
"execute_bash",
|
||||
|
||||
"web_fetch",
|
||||
"web_search",
|
||||
|
||||
"http_request",
|
||||
"http_get",
|
||||
"http_post",
|
||||
"http_put",
|
||||
"http_delete",
|
||||
|
||||
"send_notification",
|
||||
"send_email",
|
||||
"send_webhook",
|
||||
|
||||
"get_current_time",
|
||||
"format_time",
|
||||
]
|
||||
@@ -1,100 +0,0 @@
|
||||
"""
|
||||
工具实现模块
|
||||
"""
|
||||
from .files import (
|
||||
read_file,
|
||||
write_file,
|
||||
list_dir,
|
||||
delete_file,
|
||||
search_files,
|
||||
READ_FILE_TOOL,
|
||||
WRITE_FILE_TOOL,
|
||||
LIST_DIR_TOOL,
|
||||
DELETE_FILE_TOOL,
|
||||
SEARCH_FILES_TOOL,
|
||||
)
|
||||
|
||||
from .executor import (
|
||||
execute_python,
|
||||
execute_javascript,
|
||||
execute_bash,
|
||||
EXECUTE_PYTHON_TOOL,
|
||||
EXECUTE_JAVASCRIPT_TOOL,
|
||||
EXECUTE_BASH_TOOL,
|
||||
)
|
||||
|
||||
from .web import (
|
||||
web_fetch,
|
||||
web_search,
|
||||
WEB_FETCH_TOOL,
|
||||
WEB_SEARCH_TOOL,
|
||||
)
|
||||
|
||||
from .http import (
|
||||
http_request,
|
||||
http_get,
|
||||
http_post,
|
||||
http_put,
|
||||
http_delete,
|
||||
HTTP_REQUEST_TOOL,
|
||||
)
|
||||
|
||||
from .notify import (
|
||||
send_notification,
|
||||
send_email,
|
||||
send_webhook,
|
||||
SEND_NOTIFICATION_TOOL,
|
||||
)
|
||||
|
||||
from .time_tool import (
|
||||
get_current_time,
|
||||
format_time,
|
||||
GET_CURRENT_TIME_TOOL,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 文件操作
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_dir",
|
||||
"delete_file",
|
||||
"search_files",
|
||||
"READ_FILE_TOOL",
|
||||
"WRITE_FILE_TOOL",
|
||||
"LIST_DIR_TOOL",
|
||||
"DELETE_FILE_TOOL",
|
||||
"SEARCH_FILES_TOOL",
|
||||
|
||||
# 代码执行
|
||||
"execute_python",
|
||||
"execute_javascript",
|
||||
"execute_bash",
|
||||
"EXECUTE_PYTHON_TOOL",
|
||||
"EXECUTE_JAVASCRIPT_TOOL",
|
||||
"EXECUTE_BASH_TOOL",
|
||||
|
||||
# 网页
|
||||
"web_fetch",
|
||||
"web_search",
|
||||
"WEB_FETCH_TOOL",
|
||||
"WEB_SEARCH_TOOL",
|
||||
|
||||
# HTTP
|
||||
"http_request",
|
||||
"http_get",
|
||||
"http_post",
|
||||
"http_put",
|
||||
"http_delete",
|
||||
"HTTP_REQUEST_TOOL",
|
||||
|
||||
# 通知
|
||||
"send_notification",
|
||||
"send_email",
|
||||
"send_webhook",
|
||||
"SEND_NOTIFICATION_TOOL",
|
||||
|
||||
# 时间
|
||||
"get_current_time",
|
||||
"format_time",
|
||||
"GET_CURRENT_TIME_TOOL",
|
||||
]
|
||||
@@ -1,334 +0,0 @@
|
||||
"""
|
||||
代码执行工具
|
||||
提供安全的Python、JavaScript、Bash代码执行
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ExecutorConfig:
|
||||
"""执行器配置"""
|
||||
MAX_EXECUTION_TIME = 30 # 最大执行时间(秒)
|
||||
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小(1MB)
|
||||
MAX_MEMORY_MB = 256 # 最大内存(MB)
|
||||
ALLOWED_PYTHON_PACKAGES = [] # 允许的Python包(空=仅标准库)
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
"""
|
||||
代码执行器 - 在沙盒环境中执行代码
|
||||
|
||||
安全特性:
|
||||
- 临时目录隔离
|
||||
- 超时控制
|
||||
- 输出大小限制
|
||||
- 环境变量限制
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ExecutorConfig] = None):
|
||||
self.config = config or ExecutorConfig()
|
||||
self.temp_dir: Optional[str] = None
|
||||
|
||||
def _setup_temp_dir(self) -> str:
|
||||
"""创建临时目录"""
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="executor_")
|
||||
return self.temp_dir
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理临时目录"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
try:
|
||||
shutil.rmtree(self.temp_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_safe_env(self) -> Dict[str, str]:
|
||||
"""获取安全的环境变量"""
|
||||
return {
|
||||
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
|
||||
"LANG": "en_US.UTF-8",
|
||||
"HOME": self.temp_dir or "/tmp",
|
||||
"TMPDIR": self.temp_dir or "/tmp",
|
||||
}
|
||||
|
||||
def execute_python(
|
||||
self,
|
||||
code: str,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行Python代码
|
||||
|
||||
Args:
|
||||
code: Python代码
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
# 写入临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.py")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
["python", temp_file],
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
env=self._get_safe_env(),
|
||||
)
|
||||
|
||||
return self._process_result(result)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"language": "python"
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Python not installed",
|
||||
"language": "python"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"language": "python"
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def execute_javascript(
|
||||
self,
|
||||
code: str,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行JavaScript代码
|
||||
|
||||
Args:
|
||||
code: JavaScript代码
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
# 写入临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.js")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
["node", temp_file],
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
)
|
||||
|
||||
return self._process_result(result)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"language": "javascript"
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Node.js not installed",
|
||||
"language": "javascript"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"language": "javascript"
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def execute_bash(
|
||||
self,
|
||||
command: str,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行Bash命令
|
||||
|
||||
Args:
|
||||
command: Bash命令
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
self._setup_temp_dir()
|
||||
|
||||
# 安全检查:禁止的危险命令
|
||||
dangerous_patterns = [
|
||||
"rm -rf /",
|
||||
"mkfs",
|
||||
"dd if=",
|
||||
">:/dev/sd",
|
||||
"chmod 777 /",
|
||||
"chown -R",
|
||||
]
|
||||
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in command:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Dangerous command blocked: {pattern}",
|
||||
"language": "bash"
|
||||
}
|
||||
|
||||
try:
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
["bash", "-c", command],
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
env=self._get_safe_env(),
|
||||
)
|
||||
|
||||
return self._process_result(result)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"language": "bash"
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Bash not installed",
|
||||
"language": "bash"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"language": "bash"
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _process_result(self, result: subprocess.CompletedProcess) -> Dict[str, Any]:
|
||||
"""处理执行结果"""
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
# 截断输出
|
||||
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
|
||||
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
|
||||
# 全局执行器实例
|
||||
executor = CodeExecutor()
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def execute_python(code: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""执行Python代码"""
|
||||
return executor.execute_python(code, timeout)
|
||||
|
||||
|
||||
def execute_javascript(code: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""执行JavaScript代码"""
|
||||
return executor.execute_javascript(code, timeout)
|
||||
|
||||
|
||||
def execute_bash(command: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""执行Bash命令"""
|
||||
return executor.execute_bash(command, timeout)
|
||||
|
||||
|
||||
# 工具定义
|
||||
EXECUTE_PYTHON_TOOL = {
|
||||
"name": "execute_python",
|
||||
"description": "Execute Python code in a sandboxed environment. Use this for Python programming tasks, calculations, and data processing.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The Python code to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30, max: 60)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
|
||||
EXECUTE_JAVASCRIPT_TOOL = {
|
||||
"name": "execute_javascript",
|
||||
"description": "Execute JavaScript code in a sandboxed environment. Use this for JavaScript programming tasks.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The JavaScript code to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
|
||||
EXECUTE_BASH_TOOL = {
|
||||
"name": "execute_bash",
|
||||
"description": "Execute a bash command in a sandboxed environment. Use this for shell operations, file management, and system commands.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
@@ -1,444 +0,0 @@
|
||||
"""
|
||||
文件操作工具
|
||||
提供安全的文件读写、目录操作、搜索功能
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import glob as glob_module
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class FileToolConfig:
|
||||
"""文件工具配置"""
|
||||
# 允许访问的基础目录(限制在项目内)
|
||||
ALLOWED_BASE_DIRS = [
|
||||
"account", # 用户工作区
|
||||
"temp", # 临时文件
|
||||
]
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
MAX_SEARCH_RESULTS = 100
|
||||
|
||||
|
||||
def _resolve_safe_path(base_path: str, relative_path: str) -> str:
|
||||
"""
|
||||
解析安全的文件路径
|
||||
确保路径不会超出基础目录
|
||||
"""
|
||||
# 规范化路径
|
||||
full_path = os.path.normpath(os.path.join(base_path, relative_path))
|
||||
|
||||
# 检查是否在允许的基础目录内
|
||||
path_parts = Path(full_path).parts
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError("Invalid path: too short")
|
||||
|
||||
base_dir = path_parts[0]
|
||||
if base_dir not in FileToolConfig.ALLOWED_BASE_DIRS and not base_dir.endswith(".py"):
|
||||
# 允许 account 下的子目录
|
||||
if len(path_parts) >= 2 and path_parts[0] != "account":
|
||||
raise ValueError(f"Path not in allowed directories: {base_dir}")
|
||||
|
||||
return full_path
|
||||
|
||||
|
||||
def read_file(file_path: str, encoding: str = "utf-8") -> Dict[str, Any]:
|
||||
"""
|
||||
读取文件内容
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
encoding: 文件编码
|
||||
|
||||
Returns:
|
||||
文件内容
|
||||
"""
|
||||
try:
|
||||
# 安全检查
|
||||
full_path = _resolve_safe_path("", file_path)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"File not found: {file_path}"
|
||||
}
|
||||
|
||||
if not os.path.isfile(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Not a file: {file_path}"
|
||||
}
|
||||
|
||||
# 检查文件大小
|
||||
file_size = os.path.getsize(full_path)
|
||||
if file_size > FileToolConfig.MAX_FILE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"File too large: {file_size} bytes (max {FileToolConfig.MAX_FILE_SIZE})"
|
||||
}
|
||||
|
||||
# 读取内容
|
||||
with open(full_path, "r", encoding=encoding, errors="replace") as f:
|
||||
content = f.read()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"file_path": file_path,
|
||||
"size": file_size,
|
||||
"encoding": encoding
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Read error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def write_file(file_path: str, content: str, encoding: str = "utf-8") -> Dict[str, Any]:
|
||||
"""
|
||||
写入文件内容
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
content: 文件内容
|
||||
encoding: 文件编码
|
||||
|
||||
Returns:
|
||||
写入结果
|
||||
"""
|
||||
try:
|
||||
# 安全检查
|
||||
full_path = _resolve_safe_path("", file_path)
|
||||
|
||||
# 检查内容大小
|
||||
if len(content.encode(encoding)) > FileToolConfig.MAX_FILE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Content too large: {len(content)} bytes"
|
||||
}
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
|
||||
# 写入内容
|
||||
with open(full_path, "w", encoding=encoding) as f:
|
||||
f.write(content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": file_path,
|
||||
"bytes_written": len(content.encode(encoding))
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Write error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def list_dir(dir_path: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
列出目录内容
|
||||
|
||||
Args:
|
||||
dir_path: 目录路径
|
||||
|
||||
Returns:
|
||||
目录内容列表
|
||||
"""
|
||||
try:
|
||||
full_path = _resolve_safe_path("", dir_path)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Directory not found: {dir_path}"
|
||||
}
|
||||
|
||||
if not os.path.isdir(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Not a directory: {dir_path}"
|
||||
}
|
||||
|
||||
items = []
|
||||
for item in os.listdir(full_path):
|
||||
item_path = os.path.join(full_path, item)
|
||||
is_dir = os.path.isdir(item_path)
|
||||
try:
|
||||
size = 0 if is_dir else os.path.getsize(item_path)
|
||||
except:
|
||||
size = 0
|
||||
|
||||
items.append({
|
||||
"name": item,
|
||||
"type": "directory" if is_dir else "file",
|
||||
"size": size
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"path": dir_path,
|
||||
"items": items,
|
||||
"count": len(items)
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"List error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def delete_file(file_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
删除文件或目录
|
||||
|
||||
Args:
|
||||
file_path: 文件或目录路径
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
full_path = _resolve_safe_path("", file_path)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Path not found: {file_path}"
|
||||
}
|
||||
|
||||
# 删除
|
||||
if os.path.isfile(full_path):
|
||||
os.remove(full_path)
|
||||
elif os.path.isdir(full_path):
|
||||
shutil.rmtree(full_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": file_path,
|
||||
"deleted": True
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Delete error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def search_files(
|
||||
directory: str,
|
||||
pattern: str = "*",
|
||||
content_pattern: Optional[str] = None,
|
||||
file_only: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
搜索文件
|
||||
|
||||
Args:
|
||||
directory: 搜索目录
|
||||
pattern: 文件名匹配模式 (glob)
|
||||
content_pattern: 文件内容匹配模式 (可选)
|
||||
file_only: 是否只返回文件
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
try:
|
||||
full_path = _resolve_safe_path("", directory)
|
||||
|
||||
if not os.path.exists(full_path) or not os.path.isdir(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid directory: {directory}"
|
||||
}
|
||||
|
||||
results = []
|
||||
|
||||
# 按文件名搜索
|
||||
for match in glob_module.glob(os.path.join(full_path, "**", pattern), recursive=True):
|
||||
if file_only and os.path.isdir(match):
|
||||
continue
|
||||
|
||||
rel_path = os.path.relpath(match, full_path)
|
||||
|
||||
# 如果没有内容搜索,直接添加
|
||||
if not content_pattern:
|
||||
results.append({
|
||||
"path": rel_path,
|
||||
"name": os.path.basename(match),
|
||||
"type": "directory" if os.path.isdir(match) else "file"
|
||||
})
|
||||
continue
|
||||
|
||||
# 内容搜索
|
||||
if os.path.isfile(match):
|
||||
try:
|
||||
# 检查文件大小
|
||||
if os.path.getsize(match) > FileToolConfig.MAX_FILE_SIZE:
|
||||
continue
|
||||
|
||||
with open(match, "r", encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
if content_pattern.lower() in content.lower():
|
||||
results.append({
|
||||
"path": rel_path,
|
||||
"name": os.path.basename(match),
|
||||
"type": "file",
|
||||
"match": content_pattern
|
||||
})
|
||||
except:
|
||||
continue
|
||||
|
||||
# 限制结果数量
|
||||
if len(results) > FileToolConfig.MAX_SEARCH_RESULTS:
|
||||
results = results[:FileToolConfig.MAX_SEARCH_RESULTS]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"directory": directory,
|
||||
"pattern": pattern,
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Search error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
READ_FILE_TOOL = {
|
||||
"name": "read_file",
|
||||
"description": "Read the contents of a file from the filesystem.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to read"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "File encoding (default: utf-8)",
|
||||
"default": "utf-8"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
|
||||
WRITE_FILE_TOOL = {
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file. Creates the file if it doesn't exist, overwrites if it does.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to write"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "File encoding (default: utf-8)",
|
||||
"default": "utf-8"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "content"]
|
||||
}
|
||||
}
|
||||
|
||||
LIST_DIR_TOOL = {
|
||||
"name": "list_dir",
|
||||
"description": "List the contents of a directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dir_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the directory to list (default: current directory)",
|
||||
"default": "."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DELETE_FILE_TOOL = {
|
||||
"name": "delete_file",
|
||||
"description": "Delete a file or directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file or directory to delete"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
|
||||
SEARCH_FILES_TOOL = {
|
||||
"name": "search_files",
|
||||
"description": "Search for files by name pattern or content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"directory": {
|
||||
"type": "string",
|
||||
"description": "The directory to search in"
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern for file names (e.g., '*.py', '*.txt')",
|
||||
"default": "*"
|
||||
},
|
||||
"content_pattern": {
|
||||
"type": "string",
|
||||
"description": "Optional: search for files containing this text in their content"
|
||||
},
|
||||
"file_only": {
|
||||
"type": "boolean",
|
||||
"description": "Only return files, not directories",
|
||||
"default": True
|
||||
}
|
||||
},
|
||||
"required": ["directory"]
|
||||
}
|
||||
}
|
||||
@@ -1,271 +0,0 @@
|
||||
"""
|
||||
HTTP请求工具
|
||||
提供通用的HTTP API调用功能
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
|
||||
class HTTPClientConfig:
|
||||
"""HTTP客户端配置"""
|
||||
DEFAULT_TIMEOUT = 30 # 默认超时(秒)
|
||||
MAX_RESPONSE_SIZE = 5 * 1024 * 1024 # 最大响应大小(5MB)
|
||||
MAX_REDIRECTS = 5 # 最大重定向次数
|
||||
ALLOWED_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
|
||||
|
||||
|
||||
class HTTPClient:
|
||||
"""
|
||||
HTTP客户端工具
|
||||
|
||||
安全特性:
|
||||
- 只允许特定HTTP方法
|
||||
- 响应大小限制
|
||||
- 超时控制
|
||||
- 请求/响应日志
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.default_timeout = HTTPClientConfig.DEFAULT_TIMEOUT
|
||||
|
||||
async def request(
|
||||
self,
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Any] = None,
|
||||
timeout: Optional[int] = None,
|
||||
allow_redirects: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 目标URL
|
||||
method: HTTP方法
|
||||
params: 查询参数
|
||||
headers: 请求头
|
||||
json_data: JSON请求体
|
||||
data: 原始请求体
|
||||
timeout: 超时时间
|
||||
allow_redirects: 是否允许重定向
|
||||
|
||||
Returns:
|
||||
响应结果
|
||||
"""
|
||||
# 安全检查:方法
|
||||
method = method.upper()
|
||||
if method not in HTTPClientConfig.ALLOWED_METHODS:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Method '{method}' not allowed. Allowed: {HTTPClientConfig.ALLOWED_METHODS}"
|
||||
}
|
||||
|
||||
# 安全检查:协议
|
||||
if not url.startswith(("http://", "https://")):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Only HTTP and HTTPS protocols are allowed"
|
||||
}
|
||||
|
||||
timeout = timeout or self.default_timeout
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
max_redirects=HTTPClientConfig.MAX_REDIRECTS if allow_redirects else 0,
|
||||
follow_redirects=allow_redirects,
|
||||
) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
json=json_data,
|
||||
content=data,
|
||||
)
|
||||
|
||||
# 检查响应大小
|
||||
content_length = len(response.content)
|
||||
if content_length > HTTPClientConfig.MAX_RESPONSE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Response too large: {content_length} bytes"
|
||||
}
|
||||
|
||||
# 解析响应
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
return {
|
||||
"success": True,
|
||||
"status_code": response.status_code,
|
||||
"url": str(response.url),
|
||||
"headers": dict(response.headers),
|
||||
"json": response.json()
|
||||
}
|
||||
except:
|
||||
pass
|
||||
|
||||
# 文本响应
|
||||
return {
|
||||
"success": True,
|
||||
"status_code": response.status_code,
|
||||
"url": str(response.url),
|
||||
"headers": dict(response.headers),
|
||||
"text": response.text[:HTTPClientConfig.MAX_RESPONSE_SIZE]
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Request timeout ({timeout}s)"
|
||||
}
|
||||
except httpx.InvalidURL:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Invalid URL"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get(
|
||||
self,
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送GET请求"""
|
||||
return await self.request(url, "GET", params, headers, timeout=timeout)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Any] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送POST请求"""
|
||||
return await self.request(url, "POST", None, headers, json_data, data, timeout)
|
||||
|
||||
async def put(
|
||||
self,
|
||||
url: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送PUT请求"""
|
||||
return await self.request(url, "PUT", None, headers, json_data, None, timeout)
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送DELETE请求"""
|
||||
return await self.request(url, "DELETE", None, headers, timeout=timeout)
|
||||
|
||||
|
||||
# 全局HTTP客户端
|
||||
http_client = HTTPClient()
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def http_request(
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
return await http_client.request(url, method, params, headers, json_data, None, timeout)
|
||||
|
||||
|
||||
async def http_get(
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送GET请求"""
|
||||
return await http_client.get(url, params, headers, timeout)
|
||||
|
||||
|
||||
async def http_post(
|
||||
url: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送POST请求"""
|
||||
return await http_client.post(url, json_data, None, headers, timeout)
|
||||
|
||||
|
||||
async def http_put(
|
||||
url: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送PUT请求"""
|
||||
return await http_client.put(url, json_data, headers, timeout)
|
||||
|
||||
|
||||
async def http_delete(
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送DELETE请求"""
|
||||
return await http_client.delete(url, headers, timeout)
|
||||
|
||||
|
||||
# 工具定义
|
||||
HTTP_REQUEST_TOOL = {
|
||||
"name": "http_request",
|
||||
"description": "Make HTTP requests to APIs. Supports GET, POST, PUT, DELETE methods with JSON data.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to request"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
|
||||
"default": "GET"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Query parameters for GET requests"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Request headers"
|
||||
},
|
||||
"json_data": {
|
||||
"type": "object",
|
||||
"description": "JSON body for POST/PUT requests"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
}
|
||||
@@ -1,379 +0,0 @@
|
||||
"""
|
||||
通知工具
|
||||
提供发送通知的功能(邮件、Webhook等)
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class NotificationType(Enum):
|
||||
"""通知类型"""
|
||||
EMAIL = "email"
|
||||
WEBHOOK = "webhook"
|
||||
SMS = "sms"
|
||||
DINGTALK = "dingtalk"
|
||||
WECHAT = "wechat"
|
||||
SLACK = "slack"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationConfig:
|
||||
"""通知配置"""
|
||||
# Email配置
|
||||
smtp_host: str = ""
|
||||
smtp_port: int = 587
|
||||
smtp_user: str = ""
|
||||
smtp_password: str = ""
|
||||
from_email: str = ""
|
||||
|
||||
# Webhook配置
|
||||
webhook_url: str = ""
|
||||
webhook_secret: str = ""
|
||||
|
||||
# 钉钉配置
|
||||
dingtalk_webhook: str = ""
|
||||
|
||||
# Slack配置
|
||||
slack_webhook: str = ""
|
||||
|
||||
|
||||
class NotificationTool:
|
||||
"""
|
||||
通知工具
|
||||
|
||||
支持多种通知渠道:
|
||||
- Email (SMTP)
|
||||
- Webhook
|
||||
- 钉钉
|
||||
- Slack
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[NotificationConfig] = None):
|
||||
self.config = config or NotificationConfig()
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: Optional[List[str]] = None,
|
||||
is_html: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送邮件
|
||||
|
||||
Args:
|
||||
to: 收件人
|
||||
subject: 主题
|
||||
body: 内容
|
||||
cc: 抄送列表
|
||||
is_html: 是否HTML格式
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
if not self.config.smtp_host:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Email not configured"
|
||||
}
|
||||
|
||||
try:
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
# 构建邮件
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['Subject'] = subject
|
||||
msg['From'] = self.config.from_email or self.config.smtp_user
|
||||
msg['To'] = to
|
||||
|
||||
if cc:
|
||||
msg['Cc'] = ",".join(cc)
|
||||
|
||||
# 添加内容
|
||||
content_type = "html" if is_html else "plain"
|
||||
msg.attach(MIMEText(body, content_type))
|
||||
|
||||
# 发送
|
||||
with smtplib.SMTP(self.config.smtp_host, self.config.smtp_port) as server:
|
||||
server.starttls()
|
||||
server.login(self.config.smtp_user, self.config.smtp_password)
|
||||
server.send_message(msg)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"type": "email",
|
||||
"to": to,
|
||||
"subject": subject
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"type": "email"
|
||||
}
|
||||
|
||||
async def send_webhook(
|
||||
self,
|
||||
url: str,
|
||||
data: Dict[str, Any],
|
||||
method: str = "POST",
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送Webhook
|
||||
|
||||
Args:
|
||||
url: Webhook URL
|
||||
data: 请求数据
|
||||
method: HTTP方法
|
||||
headers: 请求头
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
json=data,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return {
|
||||
"success": response.status_code < 400,
|
||||
"status_code": response.status_code,
|
||||
"type": "webhook",
|
||||
"url": url
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"type": "webhook"
|
||||
}
|
||||
|
||||
async def send_dingtalk(
|
||||
self,
|
||||
message: str,
|
||||
webhook: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送钉钉消息
|
||||
|
||||
Args:
|
||||
message: 消息内容
|
||||
webhook: 自定义webhook URL
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
url = webhook or self.config.dingtalk_webhook
|
||||
if not url:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Dingtalk webhook not configured"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json={
|
||||
"msgtype": "text",
|
||||
"text": {
|
||||
"content": message
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
return {
|
||||
"success": result.get("errcode") == 0,
|
||||
"type": "dingtalk",
|
||||
"response": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"type": "dingtalk"
|
||||
}
|
||||
|
||||
async def send_slack(
|
||||
self,
|
||||
message: str,
|
||||
channel: Optional[str] = None,
|
||||
webhook: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送Slack消息
|
||||
|
||||
Args:
|
||||
message: 消息内容
|
||||
channel: 频道
|
||||
webhook: 自定义webhook URL
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
url = webhook or self.config.slack_webhook
|
||||
if not url:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Slack webhook not configured"
|
||||
}
|
||||
|
||||
try:
|
||||
payload = {"text": message}
|
||||
if channel:
|
||||
payload["channel"] = channel
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.post(url, json=payload)
|
||||
|
||||
return {
|
||||
"success": response.status_code == 200,
|
||||
"type": "slack",
|
||||
"status_code": response.status_code
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"type": "slack"
|
||||
}
|
||||
|
||||
async def send(
|
||||
self,
|
||||
type: str,
|
||||
message: str,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
统一发送接口
|
||||
|
||||
Args:
|
||||
type: 通知类型 (email, webhook, dingtalk, slack)
|
||||
message: 消息内容
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
type = type.lower()
|
||||
|
||||
if type == "email":
|
||||
return await self.send_email(
|
||||
to=kwargs.get("to", ""),
|
||||
subject=kwargs.get("subject", "Notification"),
|
||||
body=message,
|
||||
cc=kwargs.get("cc")
|
||||
)
|
||||
elif type == "webhook":
|
||||
return await self.send_webhook(
|
||||
url=kwargs.get("url", ""),
|
||||
data=kwargs.get("data", {"message": message})
|
||||
)
|
||||
elif type == "dingtalk":
|
||||
return await self.send_dingtalk(
|
||||
message=message,
|
||||
webhook=kwargs.get("webhook")
|
||||
)
|
||||
elif type == "slack":
|
||||
return await self.send_slack(
|
||||
message=message,
|
||||
channel=kwargs.get("channel"),
|
||||
webhook=kwargs.get("webhook")
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unknown notification type: {type}"
|
||||
}
|
||||
|
||||
|
||||
# 全局通知工具
|
||||
notification_tool = NotificationTool()
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def send_notification(
|
||||
type: str,
|
||||
message: str,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""发送通知"""
|
||||
return await notification_tool.send(type, message, **kwargs)
|
||||
|
||||
|
||||
async def send_email(
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str
|
||||
) -> Dict[str, Any]:
|
||||
"""发送邮件"""
|
||||
return await notification_tool.send_email(to, subject, body)
|
||||
|
||||
|
||||
async def send_webhook(
|
||||
url: str,
|
||||
data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""发送Webhook"""
|
||||
return await notification_tool.send_webhook(url, data)
|
||||
|
||||
|
||||
# 工具定义
|
||||
SEND_NOTIFICATION_TOOL = {
|
||||
"name": "send_notification",
|
||||
"description": "Send notifications via email, webhook, dingtalk, or slack.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "Notification type: email, webhook, dingtalk, slack",
|
||||
"enum": ["email", "webhook", "dingtalk", "slack"]
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The notification message"
|
||||
},
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "For email: recipient email address"
|
||||
},
|
||||
"subject": {
|
||||
"type": "string",
|
||||
"description": "For email: email subject"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "For webhook: webhook URL"
|
||||
},
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": "For webhook: JSON data to send"
|
||||
},
|
||||
"webhook": {
|
||||
"type": "string",
|
||||
"description": "Custom webhook URL for dingtalk/slack"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "For slack: channel name"
|
||||
}
|
||||
},
|
||||
"required": ["type", "message"]
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
时间工具
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
def get_current_time(timezone: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前时间
|
||||
|
||||
Args:
|
||||
timezone: 时区名称,如 "UTC", "Asia/Shanghai"
|
||||
|
||||
Returns:
|
||||
当前时间信息
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"datetime": now.isoformat(),
|
||||
"timestamp": now.timestamp(),
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"time": now.strftime("%H:%M:%S"),
|
||||
"weekday": now.strftime("%A"),
|
||||
"timezone": timezone or "Local Time"
|
||||
}
|
||||
|
||||
|
||||
def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> Dict[str, Any]:
|
||||
"""
|
||||
格式化时间戳
|
||||
|
||||
Args:
|
||||
timestamp: Unix 时间戳
|
||||
format_str: 格式字符串
|
||||
|
||||
Returns:
|
||||
格式化后的时间
|
||||
"""
|
||||
try:
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
return {
|
||||
"success": True,
|
||||
"formatted": dt.strftime(format_str),
|
||||
"datetime": dt.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
GET_CURRENT_TIME_TOOL = {
|
||||
"name": "get_current_time",
|
||||
"description": "Get the current date and time. Useful for timestamps or scheduling.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
|
||||
"default": "Local"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,233 +0,0 @@
|
||||
"""
|
||||
网页获取工具
|
||||
提供安全的网页内容抓取功能
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class WebToolConfig:
|
||||
"""网页工具配置"""
|
||||
REQUEST_TIMEOUT = 30 # 请求超时(秒)
|
||||
MAX_RESPONSE_SIZE = 2 * 1024 * 1024 # 最大响应大小(2MB)
|
||||
MAX_REDIRECTS = 5 # 最大重定向次数
|
||||
ALLOWED_PROTOCOLS = ["http", "https"] # 允许的协议
|
||||
|
||||
|
||||
async def web_fetch(
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
body: Optional[str] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取网页内容
|
||||
|
||||
Args:
|
||||
url: 目标URL
|
||||
method: HTTP方法
|
||||
params: 查询参数
|
||||
headers: 请求头
|
||||
body: 请求体
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
网页内容
|
||||
"""
|
||||
timeout = timeout or WebToolConfig.REQUEST_TIMEOUT
|
||||
|
||||
# 安全检查:协议
|
||||
if not url.startswith(("http://", "https://")):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Only HTTP and HTTPS protocols are allowed"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
max_redirects=WebToolConfig.MAX_REDIRECTS,
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
# 发送请求
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
content=body,
|
||||
)
|
||||
|
||||
# 检查响应大小
|
||||
if len(response.content) > WebToolConfig.MAX_RESPONSE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Response too large: {len(response.content)} bytes (max {WebToolConfig.MAX_RESPONSE_SIZE})"
|
||||
}
|
||||
|
||||
# 尝试解析JSON
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
data = response.json()
|
||||
return {
|
||||
"success": True,
|
||||
"url": str(response.url),
|
||||
"status_code": response.status_code,
|
||||
"content_type": content_type,
|
||||
"data": data,
|
||||
"headers": dict(response.headers)
|
||||
}
|
||||
except:
|
||||
pass
|
||||
|
||||
# 返回文本
|
||||
return {
|
||||
"success": True,
|
||||
"url": str(response.url),
|
||||
"status_code": response.status_code,
|
||||
"content_type": content_type,
|
||||
"content": response.text[:WebToolConfig.MAX_RESPONSE_SIZE],
|
||||
"headers": dict(response.headers)
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Request timeout ({timeout}s)"
|
||||
}
|
||||
except httpx.RedirectLoop:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Too many redirects"
|
||||
}
|
||||
except httpx.InvalidURL:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Invalid URL"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
async def web_search(
|
||||
query: str,
|
||||
max_results: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
搜索网页
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"https://api.duckduckgo.com/",
|
||||
params={
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": 1,
|
||||
"skip_disambig": 1
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
results = []
|
||||
|
||||
if "RelatedTopics" in data:
|
||||
for item in data["RelatedTopics"][:max_results]:
|
||||
if "Text" in item:
|
||||
text = item.get("Text", "")
|
||||
results.append({
|
||||
"title": text.split(" - ")[0] if " - " in text else "",
|
||||
"content": text,
|
||||
"url": item.get("URL", "")
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Search API returned status {response.status_code}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
WEB_FETCH_TOOL = {
|
||||
"name": "web_fetch",
|
||||
"description": "Fetch content from a web URL. Supports GET, POST methods and can return JSON or text content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to fetch"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST)",
|
||||
"default": "GET"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Query parameters"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Request headers"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Request body (for POST)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
}
|
||||
|
||||
WEB_SEARCH_TOOL = {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information. Use this when you need to find current information or facts.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
"""
|
||||
工具注册表 - 管理所有可用工具(白名单机制)
|
||||
"""
|
||||
from typing import Any, Callable, Optional, Dict
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""工具安全等级"""
|
||||
SAFE = "safe" # 安全操作
|
||||
REVIEW = "review" # 需要审核
|
||||
DANGER = "danger" # 危险操作
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""工具元数据"""
|
||||
name: str
|
||||
description: str
|
||||
security_level: str
|
||||
require_approval: bool = False
|
||||
allowed_roles: list = None
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"security_level": self.security_level,
|
||||
"require_approval": self.require_approval
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, tuple[Callable, ToolMetadata]] = {}
|
||||
self._definitions: Dict[str, dict] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
func: Callable,
|
||||
description: str = "",
|
||||
security_level: str = "safe",
|
||||
require_approval: bool = False,
|
||||
allowed_roles: list = None,
|
||||
parameters: dict = None
|
||||
):
|
||||
"""注册工具到白名单"""
|
||||
metadata = ToolMetadata(
|
||||
name=name,
|
||||
description=description,
|
||||
security_level=security_level,
|
||||
require_approval=require_approval,
|
||||
allowed_roles=allowed_roles or ["user", "admin"]
|
||||
)
|
||||
|
||||
self._tools[name] = (func, metadata)
|
||||
|
||||
# 生成工具定义(用于 LLM 调用)
|
||||
self._definitions[name] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
|
||||
"""获取工具函数和元数据"""
|
||||
if name not in self._tools:
|
||||
raise ValueError(f"Tool '{name}' not found in whitelist")
|
||||
return self._tools[name]
|
||||
|
||||
def get_tool_definition(self, name: str) -> Optional[dict]:
|
||||
"""获取工具定义(用于 LLM)"""
|
||||
return self._definitions.get(name)
|
||||
|
||||
def list_tools(self) -> list[ToolMetadata]:
|
||||
"""列出所有已注册工具"""
|
||||
return [meta for _, meta in self._tools.values()]
|
||||
|
||||
def list_definitions(self) -> list[dict]:
|
||||
"""列出所有工具定义(用于LLM)"""
|
||||
return list(self._definitions.values())
|
||||
|
||||
def check_permission(self, tool_name: str, user_role: str) -> bool:
|
||||
"""检查用户权限"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return user_role in metadata.allowed_roles
|
||||
|
||||
def need_approval(self, tool_name: str) -> bool:
|
||||
"""判断是否需要审批"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return metadata.require_approval
|
||||
|
||||
|
||||
# 全局工具注册表
|
||||
global_registry = ToolRegistry()
|
||||
@@ -1,16 +0,0 @@
|
||||
"""
|
||||
沙盒模块
|
||||
"""
|
||||
from .sandbox import (
|
||||
Sandbox,
|
||||
SandboxConfig,
|
||||
SafeEval,
|
||||
sandbox,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Sandbox",
|
||||
"SandboxConfig",
|
||||
"SafeEval",
|
||||
"sandbox",
|
||||
]
|
||||
@@ -1,267 +0,0 @@
|
||||
"""
|
||||
沙盒执行环境 - 在项目内构建,不依赖 Docker
|
||||
提供安全的代码执行环境
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SandboxConfig:
|
||||
"""沙盒配置"""
|
||||
# 资源限制
|
||||
MAX_MEMORY_MB = 256 # 最大内存 (MB)
|
||||
MAX_CPU_PERCENT = 50 # 最大 CPU 百分比
|
||||
MAX_EXECUTION_TIME = 30 # 最大执行时间 (秒)
|
||||
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小 (bytes)
|
||||
|
||||
|
||||
class Sandbox:
|
||||
"""
|
||||
沙盒执行器 - 使用 subprocess 隔离执行
|
||||
|
||||
安全特性:
|
||||
- 内存限制
|
||||
- CPU限制
|
||||
- 超时控制
|
||||
- 网络隔离(可选)
|
||||
- 临时文件隔离
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SandboxConfig] = None):
|
||||
self.config = config or SandboxConfig()
|
||||
self.temp_dir = None
|
||||
|
||||
def _setup_temp_dir(self) -> str:
|
||||
"""创建临时目录"""
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="sandbox_")
|
||||
return self.temp_dir
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理临时目录"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
try:
|
||||
shutil.rmtree(self.temp_dir)
|
||||
except Exception as e:
|
||||
print(f"Cleanup error: {e}")
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
在沙盒中执行代码
|
||||
|
||||
Args:
|
||||
code: 要执行的代码
|
||||
language: 语言类型 (python, javascript)
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
if language == "python":
|
||||
return self._execute_python(code, timeout)
|
||||
elif language == "javascript":
|
||||
return self._execute_javascript(code, timeout)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unsupported language: {language}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _execute_python(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 Python 代码"""
|
||||
# 创建临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.py")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 构建命令
|
||||
cmd = ["python", temp_file]
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir, # 限制工作目录
|
||||
env=self._get_restricted_env(), # 限制环境变量
|
||||
)
|
||||
|
||||
# 检查输出大小
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
|
||||
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _execute_javascript(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 JavaScript 代码"""
|
||||
temp_file = os.path.join(self.temp_dir, "code.js")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 尝试使用 node
|
||||
cmd = ["node", temp_file]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
)
|
||||
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Node.js not installed",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _get_restricted_env(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取受限的环境变量
|
||||
移除敏感变量,保留必要的 PATH
|
||||
"""
|
||||
# 保留 PATH,移除其他敏感变量
|
||||
safe_env = {
|
||||
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
|
||||
"LANG": "en_US.UTF-8",
|
||||
"HOME": self.temp_dir,
|
||||
"TMPDIR": self.temp_dir,
|
||||
}
|
||||
|
||||
# 移除可能不安全的变量
|
||||
unsafe_vars = [
|
||||
"PYTHONPATH",
|
||||
"PYTHONHOME",
|
||||
"LD_PRELOAD",
|
||||
"LD_LIBRARY_PATH",
|
||||
]
|
||||
|
||||
for var in unsafe_vars:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
return safe_env
|
||||
|
||||
|
||||
class SafeEval:
|
||||
"""
|
||||
安全求值器 - 用于简单表达式计算
|
||||
比沙盒更轻量,适用于不需要完全隔离的场景
|
||||
"""
|
||||
|
||||
# 安全函数白名单
|
||||
SAFE_BUILTINS = {
|
||||
"abs": abs,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"len": len,
|
||||
"round": round,
|
||||
"pow": pow,
|
||||
"print": print,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"set": set,
|
||||
"range": range,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"sorted": sorted,
|
||||
"reversed": reversed,
|
||||
}
|
||||
|
||||
# 安全数学常量
|
||||
SAFE_MATH = {
|
||||
"pi": 3.14159265359,
|
||||
"e": 2.71828182846,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def eval(cls, expression: str) -> Any:
|
||||
"""
|
||||
安全地求值表达式
|
||||
|
||||
Args:
|
||||
expression: 数学表达式
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
"""
|
||||
# 预处理表达式
|
||||
expression = expression.replace("sqrt", "**0.5")
|
||||
|
||||
# 构建安全命名空间
|
||||
safe_namespace = {
|
||||
**cls.SAFE_BUILTINS,
|
||||
**cls.SAFE_MATH,
|
||||
"__builtins__": {} # 禁用__builtins__
|
||||
}
|
||||
|
||||
try:
|
||||
result = eval(expression, safe_namespace)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise ValueError(f"Evaluation error: {e}")
|
||||
|
||||
|
||||
# 全局沙盒实例
|
||||
sandbox = Sandbox()
|
||||
@@ -1,347 +0,0 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"tools": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"description": "Read the contents of a file from the filesystem.",
|
||||
"category": "file",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to read"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "File encoding (default: utf-8)",
|
||||
"default": "utf-8"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file. Creates the file if it doesn't exist, overwrites if it does.",
|
||||
"category": "file",
|
||||
"security_level": "review",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to write"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "File encoding (default: utf-8)",
|
||||
"default": "utf-8"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "content"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "list_dir",
|
||||
"description": "List the contents of a directory.",
|
||||
"category": "file",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dir_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the directory to list (default: current directory)",
|
||||
"default": "."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "delete_file",
|
||||
"description": "Delete a file or directory.",
|
||||
"category": "file",
|
||||
"security_level": "danger",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file or directory to delete"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "search_files",
|
||||
"description": "Search for files by name pattern or content.",
|
||||
"category": "file",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"directory": {
|
||||
"type": "string",
|
||||
"description": "The directory to search in"
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern for file names (e.g., '*.py', '*.txt')",
|
||||
"default": "*"
|
||||
},
|
||||
"content_pattern": {
|
||||
"type": "string",
|
||||
"description": "Optional: search for files containing this text in their content"
|
||||
},
|
||||
"file_only": {
|
||||
"type": "boolean",
|
||||
"description": "Only return files, not directories",
|
||||
"default": true
|
||||
}
|
||||
},
|
||||
"required": ["directory"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "execute_python",
|
||||
"description": "Execute Python code in a sandboxed environment. Use this for Python programming tasks, calculations, and data processing.",
|
||||
"category": "executor",
|
||||
"security_level": "review",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The Python code to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30, max: 60)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "execute_javascript",
|
||||
"description": "Execute JavaScript code in a sandboxed environment. Use this for JavaScript programming tasks.",
|
||||
"category": "executor",
|
||||
"security_level": "review",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The JavaScript code to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "execute_bash",
|
||||
"description": "Execute a bash command in a sandboxed environment. Use this for shell operations, file management, and system commands.",
|
||||
"category": "executor",
|
||||
"security_level": "danger",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "web_fetch",
|
||||
"description": "Fetch content from a web URL. Supports GET, POST methods and can return JSON or text content.",
|
||||
"category": "web",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to fetch"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST)",
|
||||
"default": "GET"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Query parameters"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Request headers"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Request body (for POST)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information. Use this when you need to find current information or facts.",
|
||||
"category": "web",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "http_request",
|
||||
"description": "Make HTTP requests to APIs. Supports GET, POST, PUT, DELETE methods with JSON data.",
|
||||
"category": "http",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to request"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
|
||||
"default": "GET"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Query parameters for GET requests"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Request headers"
|
||||
},
|
||||
"json_data": {
|
||||
"type": "object",
|
||||
"description": "JSON body for POST/PUT requests"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "send_notification",
|
||||
"description": "Send notifications via email, webhook, dingtalk, or slack.",
|
||||
"category": "notification",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "Notification type: email, webhook, dingtalk, slack",
|
||||
"enum": ["email", "webhook", "dingtalk", "slack"]
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The notification message"
|
||||
},
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "For email: recipient email address"
|
||||
},
|
||||
"subject": {
|
||||
"type": "string",
|
||||
"description": "For email: email subject"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "For webhook: webhook URL"
|
||||
},
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": "For webhook: JSON data to send"
|
||||
},
|
||||
"webhook": {
|
||||
"type": "string",
|
||||
"description": "Custom webhook URL for dingtalk/slack"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "For slack: channel name"
|
||||
}
|
||||
},
|
||||
"required": ["type", "message"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "get_current_time",
|
||||
"description": "Get the current date and time. Useful for timestamps or scheduling.",
|
||||
"category": "system",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
|
||||
"default": "Local"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
"""
|
||||
LLM 工厂 - 创建不同提供商的 LLM 实例
|
||||
"""
|
||||
from typing import Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
|
||||
class LLMFactory:
|
||||
"""LLM 工厂类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = "openai",
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000
|
||||
):
|
||||
self.provider = provider
|
||||
self.openai_api_key = openai_api_key
|
||||
self.anthropic_api_key = anthropic_api_key
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
self._llm = None
|
||||
|
||||
def get_llm(self):
|
||||
"""获取 LLM 实例"""
|
||||
if self._llm is not None:
|
||||
return self._llm
|
||||
|
||||
if self.provider == "openai":
|
||||
self._llm = ChatOpenAI(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
api_key=self.openai_api_key
|
||||
)
|
||||
elif self.provider == "anthropic":
|
||||
self._llm = ChatAnthropic(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
anthropic_api_key=self.anthropic_api_key
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
return self._llm
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""设置模型"""
|
||||
self.model = model
|
||||
self._llm = None # 重置 LLM 实例
|
||||
|
||||
def set_temperature(self, temperature: float):
|
||||
"""设置温度"""
|
||||
self.temperature = temperature
|
||||
if self._llm:
|
||||
self._llm.temperature = temperature
|
||||
@@ -1,58 +1,20 @@
|
||||
"""
|
||||
X-Agents Python Agent Service
|
||||
智能体引擎服务入口
|
||||
FastAPI Agent Engine Server
|
||||
"""
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
import time
|
||||
from typing import Optional
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import routes
|
||||
from app.agent.core.agent import AgentManager
|
||||
from app.security.audit import AuditLogger
|
||||
from app.agent.core import AgentCore, Supervisor, AgentConfig
|
||||
from app.agent.llm import LLMFactory
|
||||
|
||||
|
||||
# 全局组件
|
||||
agent_manager: AgentManager = None
|
||||
audit_logger: AuditLogger = None
|
||||
app = FastAPI(title="X-Agents Python Engine", version="1.0.0")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
global agent_manager, audit_logger
|
||||
|
||||
# 启动时初始化
|
||||
audit_logger = AuditLogger()
|
||||
|
||||
# 初始化 Agent 管理器
|
||||
agent_manager = AgentManager(
|
||||
llm_provider=os.getenv("LLM_PROVIDER", "openai"),
|
||||
openai_api_key=os.getenv("OPENAI_API_KEY"),
|
||||
anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
)
|
||||
|
||||
# 加载 Agent 配置
|
||||
await agent_manager.load_agents()
|
||||
|
||||
print("Agent service started successfully")
|
||||
|
||||
yield
|
||||
|
||||
# 关闭时清理
|
||||
print("Agent service shutting down")
|
||||
|
||||
|
||||
# 创建 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title="X-Agents Agent Service",
|
||||
description="AI Agent Engine for X-Agents Platform",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS 中间件
|
||||
# CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -61,24 +23,180 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(routes.router, prefix="/agent", tags=["Agent"])
|
||||
|
||||
# === 请求/响应模型 ===
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""对话请求"""
|
||||
agent_id: int
|
||||
message: str
|
||||
user_id: int = 1
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "agent",
|
||||
"version": "1.0.0"
|
||||
class TeamChatRequest(BaseModel):
|
||||
"""多智能体群聊请求"""
|
||||
supervisor_agent_id: int
|
||||
member_agent_ids: list[int]
|
||||
message: str
|
||||
user_id: int = 1
|
||||
session_id: Optional[str] = None
|
||||
strategy: str = "parallel"
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""对话响应"""
|
||||
agent_id: int
|
||||
response: str
|
||||
tool_calls: list = []
|
||||
tokens_used: int = 0
|
||||
duration_ms: int = 0
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
# === 模拟数据存储 ===
|
||||
# TODO: 后续替换为从数据库加载
|
||||
_mock_agents = {
|
||||
1: {
|
||||
"id": 1,
|
||||
"name": "数据分析助手",
|
||||
"role_description": "你是一个专业的数据分析助手,擅长分析数据、生成报告。",
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"skills": [1, 2]
|
||||
},
|
||||
2: {
|
||||
"id": 2,
|
||||
"name": "代码审查助手",
|
||||
"role_description": "你是一个专业的代码审查助手,擅长审查代码、发现bug。",
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"skills": [3]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_agent_config(agent_id: int) -> AgentConfig:
|
||||
"""获取智能体配置"""
|
||||
agent_data = _mock_agents.get(agent_id)
|
||||
if not agent_data:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
return AgentConfig(
|
||||
id=agent_data["id"],
|
||||
name=agent_data["name"],
|
||||
role_description=agent_data["role_description"],
|
||||
model_provider=agent_data["model_provider"],
|
||||
model_name=agent_data["model_name"],
|
||||
skills=agent_data.get("skills", [])
|
||||
)
|
||||
|
||||
|
||||
# === API 路由 ===
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径"""
|
||||
return {"message": "X-Agents Python Engine", "version": "1.0.0"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.post("/agent/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
"""
|
||||
单智能体对话
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 获取智能体配置
|
||||
try:
|
||||
config = get_agent_config(request.agent_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# 创建智能体实例
|
||||
agent = AgentCore(config)
|
||||
|
||||
# 生成 session_id
|
||||
session_id = request.session_id or f"session_{int(time.time())}"
|
||||
|
||||
# 执行对话
|
||||
try:
|
||||
result = await agent.run(request.message, request.user_id, session_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return ChatResponse(
|
||||
agent_id=request.agent_id,
|
||||
response=result.content,
|
||||
tool_calls=result.tool_calls,
|
||||
tokens_used=result.tokens_used,
|
||||
duration_ms=duration_ms,
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
|
||||
@app.post("/agent/team/chat")
|
||||
async def team_chat(request: TeamChatRequest):
|
||||
"""
|
||||
多智能体群聊
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 创建主智能体
|
||||
try:
|
||||
supervisor_config = get_agent_config(request.supervisor_agent_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
supervisor_agent = AgentCore(supervisor_config)
|
||||
|
||||
# 创建子智能体
|
||||
members = []
|
||||
for member_id in request.member_agent_ids:
|
||||
try:
|
||||
member_config = get_agent_config(member_id)
|
||||
members.append(AgentCore(member_config))
|
||||
except:
|
||||
continue
|
||||
|
||||
if not members:
|
||||
raise HTTPException(status_code=400, detail="No valid member agents")
|
||||
|
||||
# 创建调度器
|
||||
supervisor = Supervisor(supervisor_agent, members, request.strategy)
|
||||
|
||||
# 生成 session_id
|
||||
session_id = request.session_id or f"team_session_{int(time.time())}"
|
||||
|
||||
# 执行群聊
|
||||
try:
|
||||
result = await supervisor.run(request.message, request.user_id, session_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return {
|
||||
"message": "X-Agents Agent Service",
|
||||
"docs": "/docs"
|
||||
"supervisor_agent_id": request.supervisor_agent_id,
|
||||
"response": result["main_response"],
|
||||
"subtask_results": result["subtask_results"],
|
||||
"strategy": result["strategy"],
|
||||
"duration_ms": duration_ms,
|
||||
"session_id": session_id
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
port = int(os.getenv("AGENT_PORT", "8081"))
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"""
|
||||
审批服务 - 处理工具执行的审批流程
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ApprovalStatus(Enum):
|
||||
"""审批状态"""
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class ApprovalService:
|
||||
"""审批服务"""
|
||||
|
||||
def __init__(self):
|
||||
# 待审批队列
|
||||
self.pending: Dict[str, dict] = {}
|
||||
# 审批结果
|
||||
self.results: Dict[str, ApprovalStatus] = {}
|
||||
|
||||
async def request_approval(
|
||||
self,
|
||||
tool_name: str,
|
||||
params: dict,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
reason: str
|
||||
) -> str:
|
||||
"""
|
||||
请求审批
|
||||
|
||||
Returns:
|
||||
request_id: 审批请求ID
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request = {
|
||||
"request_id": request_id,
|
||||
"tool_name": tool_name,
|
||||
"params": params,
|
||||
"user_id": user_id,
|
||||
"agent_id": agent_id,
|
||||
"reason": reason,
|
||||
"status": ApprovalStatus.PENDING,
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.pending[request_id] = request
|
||||
self.results[request_id] = ApprovalStatus.PENDING
|
||||
|
||||
# TODO: 通知 Go 后端有新审批
|
||||
|
||||
return request_id
|
||||
|
||||
async def check_approval(self, request_id: str, timeout: int = 300) -> bool:
|
||||
"""
|
||||
检查审批状态
|
||||
|
||||
Args:
|
||||
request_id: 审批请求ID
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
是否已批准
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
start = datetime.now()
|
||||
|
||||
while (datetime.now() - start).seconds < timeout:
|
||||
status = self.results.get(request_id)
|
||||
|
||||
if status == ApprovalStatus.APPROVED:
|
||||
return True
|
||||
elif status == ApprovalStatus.REJECTED:
|
||||
return False
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
raise TimeoutError("Approval request timeout")
|
||||
|
||||
async def approve(self, request_id: str):
|
||||
"""批准请求"""
|
||||
if request_id in self.pending:
|
||||
self.pending[request_id]["status"] = ApprovalStatus.APPROVED
|
||||
self.results[request_id] = ApprovalStatus.APPROVED
|
||||
|
||||
async def reject(self, request_id: str):
|
||||
"""拒绝请求"""
|
||||
if request_id in self.pending:
|
||||
self.pending[request_id]["status"] = ApprovalStatus.REJECTED
|
||||
self.results[request_id] = ApprovalStatus.REJECTED
|
||||
|
||||
def get_pending(self) -> list[dict]:
|
||||
"""获取待审批列表"""
|
||||
return [
|
||||
req for req in self.pending.values()
|
||||
if req["status"] == ApprovalStatus.PENDING
|
||||
]
|
||||
@@ -1,81 +0,0 @@
|
||||
"""
|
||||
审计日志 - 记录所有 Agent 操作
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""审计日志记录器"""
|
||||
|
||||
def __init__(self, log_file: str = "audit.log"):
|
||||
self.log_file = log_file
|
||||
|
||||
def log(
|
||||
self,
|
||||
action: str,
|
||||
agent_id: str = "",
|
||||
session_id: str = "",
|
||||
user_id: str = "",
|
||||
details: Dict[str, Any] = None,
|
||||
result: str = "success"
|
||||
):
|
||||
"""记录审计日志"""
|
||||
entry = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"action": action,
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"details": details or {},
|
||||
"result": result
|
||||
}
|
||||
|
||||
# 写入文件
|
||||
self._write_log(entry)
|
||||
|
||||
# TODO: 发送到 Go 后端
|
||||
|
||||
def log_tool_execution(
|
||||
self,
|
||||
tool_name: str,
|
||||
params: Dict[str, Any],
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
approved: bool,
|
||||
result: Any
|
||||
):
|
||||
"""记录工具执行"""
|
||||
self.log(
|
||||
action="tool_execution",
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
details={
|
||||
"tool_name": tool_name,
|
||||
"params": params,
|
||||
"approved": approved,
|
||||
"result_preview": str(result)[:200] if result else None
|
||||
},
|
||||
result="approved" if approved else "pending_approval"
|
||||
)
|
||||
|
||||
def log_error(self, action: str, error: str, **kwargs):
|
||||
"""记录错误"""
|
||||
self.log(
|
||||
action=action,
|
||||
details={"error": error, **kwargs},
|
||||
result="error"
|
||||
)
|
||||
|
||||
def _write_log(self, entry: dict):
|
||||
"""写入日志文件"""
|
||||
try:
|
||||
log_path = Path(self.log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(log_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
print(f"Failed to write audit log: {e}")
|
||||
@@ -1,19 +1,8 @@
|
||||
# 核心依赖
|
||||
fastapi>=0.100.0
|
||||
uvicorn>=0.20.0
|
||||
uvicorn[standard]>=0.23.0
|
||||
pydantic>=2.0.0
|
||||
httpx>=0.24.0
|
||||
aiohttp>=3.8.0
|
||||
python-multipart>=0.0.5
|
||||
|
||||
# LLM 支持
|
||||
openai>=1.0.0
|
||||
anthropic>=0.18.0
|
||||
langchain-core>=0.1.0
|
||||
langchain-openai>=0.0.2
|
||||
|
||||
# 可选:向量数据库
|
||||
chromadb>=0.4.0
|
||||
|
||||
# Redis
|
||||
redis>=4.5.0
|
||||
python-dotenv>=1.0.0
|
||||
aiohttp>=3.8.0
|
||||
redis>=5.0.0
|
||||
|
||||
Reference in New Issue
Block a user