refactor: 重构 Agent 模块

- 删除旧的 agent 核心文件
- 新增 supervisor, memory, skills 等模块
- 重构 main.py 服务入口

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-11 16:25:37 +08:00
parent b5b2c32477
commit c6a4b28bf6
43 changed files with 385 additions and 6052 deletions

View File

@@ -1,28 +0,0 @@
# Python Agent Service Dockerfile
FROM python:3.11-slim
WORKDIR /app
# 安装系统依赖
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY requirements.txt .
# 安装 Python 依赖
RUN pip install --no-cache-dir -r requirements.txt
# 复制应用代码
COPY app/ ./app/
# 创建数据目录
RUN mkdir -p /app/data
# 暴露端口
EXPOSE 8081
# 启动服务
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8081"]

View File

@@ -1,192 +1,123 @@
"""
Agent 核心管理器
Agent Core - 单智能体核心
"""
import os
from typing import Any, Optional
from app.agent.core.executor import AgentExecutor
from app.agent.memory.session import SessionManager
from app.agent.tools.registry import ToolRegistry
from app.llm.factory import LLMFactory
from app.security.audit import AuditLogger
from typing import Optional, List, Dict, Any
from pydantic import BaseModel
from app.agent.memory.manager import MemoryManager
from app.agent.skills.router import SkillRouter
from app.agent.skills.executor import SkillExecutor
from app.agent.llm.factory import LLMFactory
class AgentManager:
"""Agent 管理器 - 负责加载和管理所有 Agent"""
class AgentConfig(BaseModel):
"""智能体配置"""
id: int
name: str
role_description: str
model_provider: str = "openai"
model_name: str = "gpt-4"
skills: List[int] = [] # 技能 ID 列表
knowledge_base_ids: List[int] = []
timeout: int = 60
memory_limit: int = 134217728 # 128MB
def __init__(
self,
llm_provider: str = "openai",
openai_api_key: Optional[str] = None,
anthropic_api_key: Optional[str] = None,
):
self.llm_provider = llm_provider
self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
self.anthropic_api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
# 初始化组件
self.llm_factory = LLMFactory(
provider=llm_provider,
openai_api_key=self.openai_api_key,
anthropic_api_key=self.anthropic_api_key
)
self.tool_registry = ToolRegistry()
self.session_manager = SessionManager()
self.audit_logger = AuditLogger()
class AgentResponse(BaseModel):
"""智能体响应"""
content: str
tool_calls: List[Dict[str, Any]] = []
tokens_used: int = 0
duration_ms: int = 0
session_id: Optional[str] = None
# 已加载的 Agent
self.agents: dict[str, dict] = {}
self.executors: dict[str, AgentExecutor] = {}
# 注册默认工具
self._register_default_tools()
class AgentCore:
"""单智能体核心类"""
def _register_default_tools(self):
"""注册默认工具"""
from app.agent.tools.impl import search, calculator, time_tool
from app.agent.tools.impl import sandbox, database, api_client
def __init__(self, config: AgentConfig):
self.config = config
self.llm = LLMFactory.create(config.model_provider, config.model_name)
self.memory = MemoryManager(config.id)
self.skill_router = SkillRouter(config.skills)
self.skill_executor = SkillExecutor()
# 安全工具 - Safe 级别
self.tool_registry.register(
name="search",
func=search.search_web,
description="Search the web for information",
security_level="safe"
)
async def run(self, user_input: str, user_id: int, session_id: str) -> AgentResponse:
"""
执行智能体对话
self.tool_registry.register(
name="calculator",
func=calculator.calculate,
description="Perform mathematical calculations",
security_level="safe"
)
Args:
user_input: 用户输入
user_id: 用户 ID
session_id: 会话 ID
self.tool_registry.register(
name="get_current_time",
func=time_tool.get_current_time,
description="Get current date and time",
security_level="safe"
)
Returns:
AgentResponse: 智能体响应
"""
import time
start_time = time.time()
# 需要审核的工具 - Review 级别
self.tool_registry.register(
name="execute_code",
func=sandbox.sandbox.execute,
description="Execute code in sandbox (Python/JavaScript)",
security_level="review",
require_approval=True,
parameters={
"type": "object",
"properties": {
"code": {"type": "string", "description": "Code to execute"},
"language": {"type": "string", "default": "python"},
"timeout": {"type": "integer", "default": 30}
},
"required": ["code"]
}
)
try:
# 1. 加载记忆
context = await self.memory.load_context(user_input, user_id, session_id)
self.tool_registry.register(
name="query_database",
func=database.query_data,
description="Query database (SELECT only)",
security_level="review",
require_approval=True,
parameters={
"type": "object",
"properties": {
"sql": {"type": "string", "description": "SELECT query"}
},
"required": ["sql"]
}
)
# 2. 构建 Prompt
prompt = self._build_prompt(user_input, context)
self.tool_registry.register(
name="call_api",
func=api_client.call_api,
description="Call external API (whitelist only)",
security_level="review",
require_approval=True,
parameters={
"type": "object",
"properties": {
"api_name": {"type": "string"},
"endpoint": {"type": "string"},
"params": {"type": "object"}
},
"required": ["api_name"]
}
)
# 3. LLM 决策
decision = await self.llm.decide(prompt)
async def load_agents(self):
"""加载 Agent 配置"""
# TODO: 从数据库或配置文件加载
# 这里先注册一些示例 Agent
# 4. 执行技能(如需)
if decision.get('needs_skill'):
skill_results = await self._execute_skills(decision.get('tool_calls', []))
# 5. 基于结果生成回复
final_response = await self.llm.generate(prompt, skill_results)
else:
final_response = decision.get('response', '')
self.agents["assistant"] = {
"name": "General Assistant",
"description": "A general purpose assistant",
"system_prompt": "You are a helpful assistant.",
"tools": ["search", "calculator", "get_current_time"]
}
# 6. 保存记忆
await self.memory.save(user_input, final_response, user_id, session_id)
self.agents["coder"] = {
"name": "Code Assistant",
"description": "Helps with coding tasks",
"system_prompt": "You are a helpful coding assistant. You can write, explain, and debug code.",
"tools": ["search", "calculator"]
}
duration_ms = int((time.time() - start_time) * 1000)
# 为每个 Agent 创建执行器
for agent_id, config in self.agents.items():
self.executors[agent_id] = AgentExecutor(
agent_id=agent_id,
llm_factory=self.llm_factory,
tool_registry=self.tool_registry,
session_manager=self.session_manager,
audit_logger=self.audit_logger,
config=config
return AgentResponse(
content=final_response,
tool_calls=decision.get('tool_calls', []),
duration_ms=duration_ms,
session_id=session_id
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
return AgentResponse(
content=f"处理请求时发生错误: {str(e)}",
duration_ms=duration_ms,
session_id=session_id
)
async def execute(
self,
agent_id: str,
message: str,
session_id: str,
context: dict = None
) -> dict[str, Any]:
"""执行 Agent"""
if agent_id not in self.executors:
raise ValueError(f"Agent '{agent_id}' not found")
def _build_prompt(self, user_input: str, context: dict) -> str:
"""构建 Prompt"""
system_prompt = f"""你是 {self.config.name}
{self.config.role_description}
executor = self.executors[agent_id]
相关记忆:
{context.get('summary', '')}
# 执行
result = await executor.run(
message=message,
session_id=session_id,
context=context or {}
)
知识库信息:
{context.get('knowledge', '')}
return result
请根据以上上下文回答用户问题。如果需要使用工具,请明确说明。"""
def list_tools(self) -> list:
"""列出所有可用工具"""
return self.tool_registry.list_tools()
return f"{system_prompt}\n\n用户: {user_input}"
def list_agents(self) -> list[dict]:
"""列出所有 Agent"""
return [
{
"id": agent_id,
"name": config["name"],
"description": config["description"]
}
for agent_id, config in self.agents.items()
]
async def _execute_skills(self, skill_decisions: List[Dict]) -> List[Dict]:
"""执行技能"""
if not skill_decisions:
return []
def get_agent_info(self, agent_id: str) -> Optional[dict]:
"""获取 Agent 信息"""
if agent_id not in self.agents:
return None
return self.agents[agent_id]
results = []
for decision in skill_decisions:
result = await self.skill_executor.execute(
skill_id=decision.get('skill_id'),
params=decision.get('params', {})
)
results.append(result)
return results

View File

@@ -1,163 +0,0 @@
"""
Agent 执行器 - 负责执行 Agent 的核心逻辑
"""
from typing import Any, Optional
from app.llm.factory import LLMFactory
from app.agent.tools.registry import ToolRegistry
from app.agent.memory.session import SessionManager
from app.security.audit import AuditLogger
class AgentExecutor:
"""Agent 执行器"""
def __init__(
self,
agent_id: str,
llm_factory: LLMFactory,
tool_registry: ToolRegistry,
session_manager: SessionManager,
audit_logger: AuditLogger,
config: dict
):
self.agent_id = agent_id
self.llm_factory = llm_factory
self.tool_registry = tool_registry
self.session_manager = session_manager
self.audit_logger = audit_logger
self.config = config
# 获取 LLM
self.llm = self.llm_factory.get_llm()
async def run(
self,
message: str,
session_id: str,
context: dict
) -> dict[str, Any]:
"""运行 Agent"""
tools_used = []
# 1. 获取会话历史
history = self.session_manager.get_history(session_id)
# 2. 构建消息列表
messages = self._build_messages(message, history)
# 3. 获取可用工具
available_tools = self._get_available_tools()
# 4. 调用 LLM带工具
try:
response = await self.llm.agenerate(
messages=messages,
tools=available_tools
)
# 检查是否需要调用工具
response_message = response.generations[0][0]
# 如果有工具调用
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
for tool_call in response_message.tool_calls:
tool_name = tool_call.name
tool_args = tool_call.arguments
# 记录工具使用
tools_used.append(tool_name)
# 执行工具
tool_result = await self._execute_tool(tool_name, tool_args)
# 添加工具结果到消息
messages.append(response_message)
messages.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": str(tool_result)
})
# 再次调用 LLM 生成最终响应
final_response = await self.llm.agenerate(messages=messages)
final_message = final_response.generations[0][0].text
# 保存到历史
self.session_manager.add_message(session_id, "user", message)
self.session_manager.add_message(session_id, "assistant", final_message)
return {
"reply": final_message,
"tools_used": tools_used,
"metadata": {}
}
# 没有工具调用,直接返回
reply = response_message.text
# 保存到历史
self.session_manager.add_message(session_id, "user", message)
self.session_manager.add_message(session_id, "assistant", reply)
return {
"reply": reply,
"tools_used": tools_used,
"metadata": {}
}
except Exception as e:
# 记录错误
self.audit_logger.log(
action="agent_error",
agent_id=self.agent_id,
session_id=session_id,
details={"error": str(e)}
)
raise
def _build_messages(self, message: str, history: list) -> list:
"""构建消息列表"""
messages = []
# 添加系统提示
system_prompt = self.config.get("system_prompt", "You are a helpful assistant.")
messages.append({"role": "system", "content": system_prompt})
# 添加历史
for msg in history:
messages.append(msg)
# 添加当前消息
messages.append({"role": "user", "content": message})
return messages
def _get_available_tools(self) -> list:
"""获取可用工具定义"""
agent_tools = self.config.get("tools", [])
tool_defs = []
for tool_name in agent_tools:
tool_def = self.tool_registry.get_tool_definition(tool_name)
if tool_def:
tool_defs.append(tool_def)
return tool_defs
async def _execute_tool(self, tool_name: str, args: dict) -> Any:
"""执行工具"""
# 安全检查
tool_func, metadata = self.tool_registry.get_tool(tool_name)
# 如果需要审批,抛出异常
if metadata.require_approval:
raise PermissionError(
f"Tool '{tool_name}' requires approval before execution"
)
# 执行工具
try:
result = tool_func(**args)
return result
except Exception as e:
return {"error": str(e)}

View File

@@ -1,62 +1,125 @@
"""
会话管理器 - 管理 Agent 的会话历史
Session Memory - 会话级记忆Redis 存储)
"""
from typing import Any, Optional
from collections import defaultdict
from datetime import datetime
import json
from typing import Optional
class SessionManager:
"""会话管理器"""
class SessionMemory:
"""会话级记忆Redis 存储"""
def __init__(self, max_history: int = 10):
def __init__(self, agent_id: int, redis_client=None):
"""
初始化会话管理器
初始化会话记忆
Args:
max_history: 每个会话保留的最大历史消息数
agent_id: 智能体 ID
redis_client: Redis 客户端(可选)
"""
self.max_history = max_history
self.sessions: dict[str, list[dict]] = defaultdict(list)
self.metadata: dict[str, dict] = {}
self.agent_id = agent_id
self.redis = redis_client
self.ttl = 3600 * 24 # 24 小时
self.summary_threshold = 10 # 多少条消息后生成摘要
def add_message(self, session_id: str, role: str, content: str):
"""添加消息到会话"""
self.sessions[session_id].append({
"role": role,
"content": content,
"timestamp": datetime.now().isoformat()
})
def _key(self, user_id: int, session_id: str) -> str:
"""生成 Redis Key"""
return f"agent:memory:session:{self.agent_id}:{user_id}:{session_id}"
# 限制历史长度
if len(self.sessions[session_id]) > self.max_history:
self.sessions[session_id] = self.sessions[session_id][-self.max_history:]
async def add(self, user_input: str, response: str, user_id: int, session_id: str):
"""
添加对话到会话记忆
def get_history(self, session_id: str) -> list[dict]:
"""获取会话历史"""
return self.sessions.get(session_id, [])
Args:
user_input: 用户输入
response: 智能体回复
user_id: 用户 ID
session_id: 会话 ID
"""
if not self.redis:
# 如果没有 Redis使用内存模拟
return self._add_memory(user_input, response, user_id, session_id)
def clear_session(self, session_id: str):
"""清除会话"""
if session_id in self.sessions:
del self.sessions[session_id]
if session_id in self.metadata:
del self.metadata[session_id]
key = self._key(user_id, session_id)
def set_metadata(self, session_id: str, key: str, value: Any):
"""设置会话元数据"""
if session_id not in self.metadata:
self.metadata[session_id] = {}
self.metadata[session_id][key] = value
# 获取现有数据
data = await self.redis.get(key)
messages = json.loads(data) if data else {"messages": [], "summary": ""}
def get_metadata(self, session_id: str, key: str, default: Any = None) -> Any:
"""获取会话元数据"""
return self.metadata.get(session_id, {}).get(key, default)
# 添加新消息
messages["messages"].append({"role": "user", "content": user_input})
messages["messages"].append({"role": "assistant", "content": response})
def list_sessions(self) -> list[str]:
"""列出所有会话ID"""
return list(self.sessions.keys())
# 定期生成摘要
if len(messages["messages"]) >= self.summary_threshold:
messages["summary"] = await self._generate_summary(messages["messages"])
def get_session_count(self) -> int:
"""获取会话数量"""
return len(self.sessions)
# 保持消息数量
if len(messages["messages"]) > 50:
messages["messages"] = messages["messages"][-50:]
await self.redis.setex(key, self.ttl, json.dumps(messages))
async def get_summary(self, user_id: int, session_id: str) -> str:
"""
获取会话摘要
Args:
user_id: 用户 ID
session_id: 会话 ID
Returns:
str: 会话摘要
"""
if not self.redis:
return self._get_memory_summary(user_id, session_id)
key = self._key(user_id, session_id)
data = await self.redis.get(key)
if data:
messages = json.loads(data)
return messages.get("summary", "")
return ""
async def _generate_summary(self, messages: list) -> str:
"""
生成摘要(简化版)
Args:
messages: 消息列表
Returns:
str: 摘要
"""
# 简化:取最后几条消息的要点
if not messages:
return ""
recent = messages[-6:] # 最近 3 轮
summary = f"最近对话包含 {len(messages)//2} 轮交互"
# TODO: 后续可以使用 LLM 生成更好的摘要
return summary
# === 内存模拟(无 Redis 时使用)===
_memory_store = {}
def _add_memory(self, user_input: str, response: str, user_id: int, session_id: str):
"""内存模拟 - 添加"""
key = f"{self.agent_id}:{user_id}:{session_id}"
if key not in self._memory_store:
self._memory_store[key] = {"messages": [], "summary": ""}
messages = self._memory_store[key]["messages"]
messages.append({"role": "user", "content": user_input})
messages.append({"role": "assistant", "content": response})
if len(messages) >= self.summary_threshold:
self._memory_store[key]["summary"] = self._generate_summary(messages)
def _get_memory_summary(self, user_id: int, session_id: str) -> str:
"""内存模拟 - 获取摘要"""
key = f"{self.agent_id}:{user_id}:{session_id}"
if key in self._memory_store:
return self._memory_store[key].get("summary", "")
return ""

View File

@@ -1,23 +0,0 @@
"""
多智能体系统
"""
from .types import AgentState, TaskItem, TaskStatus, AgentType, SupervisorDecision, ReviewResult
from .prompts import SUPERVISOR_SYSTEM_PROMPT, REVIEW_SYSTEM_PROMPT, RESEARCH_SYSTEM_PROMPT, CODER_SYSTEM_PROMPT, AGGREGATOR_SYSTEM_PROMPT
from .supervisor import SupervisorAgent
from .graph import create_multi_agent_graph
__all__ = [
"AgentState",
"TaskItem",
"TaskStatus",
"AgentType",
"SupervisorDecision",
"ReviewResult",
"SUPERVISOR_SYSTEM_PROMPT",
"REVIEW_SYSTEM_PROMPT",
"RESEARCH_SYSTEM_PROMPT",
"CODER_SYSTEM_PROMPT",
"AGGREGATOR_SYSTEM_PROMPT",
"SupervisorAgent",
"create_multi_agent_graph",
]

View File

@@ -1,130 +0,0 @@
"""
LangGraph 流程编排
"""
from langgraph.graph import StateGraph, END
from langgraph.graph.graph import CompiledGraph
from .types import AgentState, AgentType
from .supervisor import SupervisorAgent, ResultAggregator
from .workers.research import ResearchWorker
from .workers.coder import CoderWorker
from .workers.review import ReviewWorker
def create_multi_agent_graph(
llm,
tool_registry=None,
max_iterations: int = 3,
max_tasks: int = 10
) -> CompiledGraph:
"""创建多 Agent 流程图
Args:
llm: 语言模型实例
tool_registry: 工具注册表
max_iterations: 最大迭代次数
max_tasks: 最大任务数
Returns:
CompiledGraph: 编译后的 LangGraph
"""
# 初始化组件
supervisor = SupervisorAgent(llm, max_iterations=max_iterations, max_tasks=max_tasks)
research_worker = ResearchWorker(llm, tool_registry)
coder_worker = CoderWorker(llm, tool_registry)
review_worker = ReviewWorker(llm, tool_registry)
aggregator = ResultAggregator(llm)
# 创建图
graph = StateGraph(AgentState)
# 添加节点
graph.add_node("supervisor", supervisor.create_node())
graph.add_node(AgentType.RESEARCH, research_worker.create_node())
graph.add_node(AgentType.CODER, coder_worker.create_node())
graph.add_node(AgentType.REVIEW, review_worker.create_node())
graph.add_node("aggregator", aggregator.create_node())
# 设置入口点
graph.set_entry_point("supervisor")
# 定义条件边函数
def should_continue(state: AgentState) -> str:
"""判断是否继续执行"""
# 获取下一步节点
next_node = state.get("next_node", "aggregator")
# 如果是结束节点
if next_node in ["__end__", "aggregator"]:
return "aggregator"
# 如果是 Worker 节点
if next_node in [AgentType.RESEARCH, AgentType.CODER, AgentType.REVIEW]:
return next_node
# 如果是 supervisor
if next_node == "supervisor":
# 检查迭代次数
iteration = state.get("iteration", 0)
if iteration >= max_iterations:
return "aggregator"
return "supervisor"
# 默认进入汇总
return "aggregator"
# 添加条件边:从 supervisor 出来
graph.add_conditional_edges(
"supervisor",
should_continue,
{
"supervisor": "supervisor",
AgentType.RESEARCH: AgentType.RESEARCH,
AgentType.CODER: AgentType.CODER,
AgentType.REVIEW: AgentType.REVIEW,
"aggregator": "aggregator"
}
)
# 添加边Worker -> Review
graph.add_edge(AgentType.RESEARCH, AgentType.REVIEW)
graph.add_edge(AgentType.CODER, AgentType.REVIEW)
# 添加条件边:从 Review 出来
graph.add_conditional_edges(
AgentType.REVIEW,
should_continue,
{
"supervisor": "supervisor",
"aggregator": "aggregator"
}
)
# 添加边aggregator -> END
graph.add_edge("aggregator", END)
# 编译图
return graph.compile()
def create_simple_graph(llm, tool_registry=None) -> CompiledGraph:
"""创建简单的单 Agent 图(不经过 Supervisor"""
# 创建图
graph = StateGraph(AgentState)
# 直接使用 Coder Worker
coder_worker = CoderWorker(llm, tool_registry)
# 添加节点
graph.add_node("coder", coder_worker.create_node())
# 设置入口
graph.set_entry_point("coder")
# 添加边
graph.add_edge("coder", END)
return graph.compile()

View File

@@ -1,223 +0,0 @@
"""
多智能体系统 - 与现有系统集成
"""
import logging
from typing import Optional
from app.llm.factory import LLMFactory
from app.agent.tools.registry import ToolRegistry
from app.agent.memory.session import SessionManager
from .types import create_initial_state
from .graph import create_multi_agent_graph
logger = logging.getLogger(__name__)
class MultiAgentSystem:
"""多智能体系统 - 集成现有组件"""
def __init__(
self,
llm_provider: str = "openai",
openai_api_key: Optional[str] = None,
anthropic_api_key: Optional[str] = None,
max_iterations: int = 3,
max_tasks: int = 10
):
"""
初始化多智能体系统
Args:
llm_provider: LLM 提供商
openai_api_key: OpenAI API Key
anthropic_api_key: Anthropic API Key
max_iterations: 最大迭代次数
max_tasks: 最大任务数
"""
# 初始化 LLM Factory
self.llm_factory = LLMFactory(
provider=llm_provider,
openai_api_key=openai_api_key,
anthropic_api_key=anthropic_api_key
)
# 初始化 Tool Registry
self.tool_registry = ToolRegistry()
self._register_default_tools()
# 初始化 Session Manager
self.session_manager = SessionManager()
# 配置
self.max_iterations = max_iterations
self.max_tasks = max_tasks
# 图实例(延迟初始化)
self._graph = None
def _register_default_tools(self):
"""注册默认工具"""
try:
from app.agent.tools.impl import search, calculator, time_tool
# 安全工具
self.tool_registry.register(
name="search",
func=search.search_web,
description="Search the web for information",
security_level="safe"
)
self.tool_registry.register(
name="calculator",
func=calculator.calculate,
description="Perform mathematical calculations",
security_level="safe"
)
self.tool_registry.register(
name="get_current_time",
func=time_tool.get_current_time,
description="Get current date and time",
security_level="safe"
)
# 执行代码工具
try:
from app.agent.tools.impl import sandbox
self.tool_registry.register(
name="execute_code",
func=sandbox.sandbox.execute,
description="Execute code in sandbox",
security_level="review",
require_approval=True
)
except ImportError:
pass
except ImportError as e:
logger.warning(f"Failed to import default tools: {e}")
@property
def graph(self):
"""获取或创建 LangGraph"""
if self._graph is None:
llm = self.llm_factory.get_llm()
self._graph = create_multi_agent_graph(
llm=llm,
tool_registry=self.tool_registry,
max_iterations=self.max_iterations,
max_tasks=self.max_tasks
)
return self._graph
async def execute(self, task: str, session_id: str = None) -> dict:
"""
执行多 Agent 任务
Args:
task: 任务描述
session_id: 会话 ID可选
Returns:
dict: 执行结果
"""
# 创建初始状态
initial_state = create_initial_state(task, session_id)
try:
# 执行图
result = await self.graph.ainvoke(initial_state)
# 保存到 session
if session_id:
self.session_manager.add_message(session_id, "user", task)
self.session_manager.add_message(
session_id,
"assistant",
result.get("final_output", "")
)
return {
"success": result.get("status") != "failed",
"output": result.get("final_output", ""),
"status": result.get("status", "unknown"),
"task_plan": result.get("task_plan", []),
"results": result.get("results", {})
}
except Exception as e:
logger.error(f"Multi-agent execution failed: {e}")
return {
"success": False,
"output": f"执行失败: {str(e)}",
"status": "failed",
"error": str(e)
}
async def execute_simple(self, task: str, session_id: str = None) -> dict:
"""
执行简单任务(不使用 Supervisor
Args:
task: 任务描述
session_id: 会话 ID可选
Returns:
dict: 执行结果
"""
from .graph import create_simple_graph
# 创建简单图
llm = self.llm_factory.get_llm()
simple_graph = create_simple_graph(llm, self.tool_registry)
# 创建初始状态
initial_state = create_initial_state(task, session_id)
try:
# 执行图
result = await simple_graph.ainvoke(initial_state)
return {
"success": True,
"output": result.get("final_output", ""),
"status": result.get("status", "completed")
}
except Exception as e:
logger.error(f"Simple execution failed: {e}")
return {
"success": False,
"output": f"执行失败: {str(e)}",
"status": "failed",
"error": str(e)
}
def list_tools(self) -> list:
"""列出所有可用工具"""
return self.tool_registry.list_tools()
# 全局实例
_global_system: Optional[MultiAgentSystem] = None
def get_multi_agent_system(
llm_provider: str = "openai",
openai_api_key: str = None,
anthropic_api_key: str = None,
**kwargs
) -> MultiAgentSystem:
"""获取全局多智能体系统实例"""
global _global_system
if _global_system is None:
_global_system = MultiAgentSystem(
llm_provider=llm_provider,
openai_api_key=openai_api_key,
anthropic_api_key=anthropic_api_key,
**kwargs
)
return _global_system

View File

@@ -1,117 +0,0 @@
"""
迭代控制器
"""
from typing import Optional
class IterationController:
"""迭代控制器 - 管理任务执行的迭代"""
def __init__(
self,
max_iterations: int = 3,
max_retries_per_task: int = 2
):
"""
初始化迭代控制器
Args:
max_iterations: 全局最大迭代次数
max_retries_per_task: 每个任务的最大重试次数
"""
self.max_iterations = max_iterations
self.max_retries_per_task = max_retries_per_task
def should_continue(
self,
iteration: int,
task_status: str,
review_result: Optional[dict] = None
) -> tuple[bool, str]:
"""
判断是否继续迭代
Args:
iteration: 当前迭代次数
task_status: 任务状态
review_result: 评审结果(可选)
Returns:
(是否继续, 原因)
"""
# 超过最大迭代次数
if iteration >= self.max_iterations:
return False, "max_iterations_reached"
# 任务成功完成
if task_status == "completed":
if review_result and review_result.get("passed"):
return False, "task_completed"
elif review_result is None:
return False, "task_completed"
# 任务失败且不可重试
if task_status == "failed":
if review_result and not review_result.get("retryable", True):
return False, "task_failed_non_retryable"
# 检查重试次数
retry_count = review_result.get("retry_count", 0) if review_result else 0
if retry_count >= self.max_retries_per_task:
return False, "max_retries_reached"
# 需要重试
if review_result:
issues = review_result.get("issues", [])
if issues and not review_result.get("passed", True):
return True, "needs_retry"
return True, "continue"
def get_next_action(
self,
review_result: Optional[dict],
current_worker: str
) -> str:
"""
确定下一步动作
Args:
review_result: 评审结果
current_worker: 当前执行的 Worker
Returns:
下一个节点名称
"""
if review_result is None:
return "supervisor"
# 根据评审结果决定下一步
if review_result.get("passed"):
return "supervisor"
# 根据问题类型决定下一步
issues = review_result.get("issues", [])
high_severity = any(i.get("severity") == "high" for i in issues)
if high_severity:
# 严重问题,重新执行相同任务
return current_worker
else:
# 轻微问题,返回 Supervisor
return "supervisor"
def calculate_backoff_delay(self, retry_count: int) -> float:
"""
计算退避延迟(指数退避)
Args:
retry_count: 重试次数
Returns:
延迟时间(秒)
"""
base_delay = 1.0
max_delay = 30.0
delay = min(base_delay * (2 ** retry_count), max_delay)
return delay

View File

@@ -1,170 +0,0 @@
"""
多智能体系统 Prompt 模板
"""
# Supervisor System Prompt
SUPERVISOR_SYSTEM_PROMPT = """你是一个任务规划专家Supervisor。你的职责是将复杂任务分解为可执行的子任务并分配给合适的执行 Agent。
## 可用的 Worker Agent
- **research**: 信息搜索和调研
- **coder**: 代码编写、修改和调试
- **review**: 结果检查、质量评审
## 任务
{task}
## 当前进度
{progress}
## 共享上下文
{context}
## 请按以下步骤执行
### 步骤 1: 任务分析
分析任务的性质,确定需要哪些步骤来完成。
### 步骤 2: 任务分解
将任务分解为独立的子任务。每个子任务应该:
- 描述清晰
- 可以由单个 Agent 完成
- 有明确的完成标准
### 步骤 3: 分配 Agent
为每个子任务选择最合适的执行 Agent。
### 步骤 4: 确定执行顺序
如果有依赖关系,确定正确的执行顺序。
## 输出格式
请以 JSON 格式输出你的决策,包含以下字段:
- analysis: 任务分析
- task_plan: 任务计划数组,每个元素包含 id, description, assigned_agent
- need_aggregation: 是否需要汇总结果
- next_worker: 下一个执行的 Worker 名称 (research/coder/review)
## 注意
- 如果任务很简单,可以只分配给一个 Agent
- 如果任务需要迭代优化,确保有 review 环节
- 考虑任务之间的依赖关系
- 使用 "research"/"coder"/"review" 作为 assigned_agent 的值
"""
# Review Worker System Prompt
REVIEW_SYSTEM_PROMPT = """你是一个代码和结果评审专家Reviewer。你的职责是检查任务执行结果是否符合要求。
## 原始任务
{original_task}
## 当前任务描述
{task_description}
## 执行结果
{execution_result}
## 共享上下文
{context}
## 检查标准
1. 结果是否完整解决了原始任务?
2. 输出格式是否正确?
3. 是否存在明显的错误或遗漏?
4. 代码是否有潜在问题?
5. 是否有安全漏洞或风险?
## 输出格式
请以 JSON 格式输出评审结果:
- passed: true/false是否通过
- issues: 问题数组,每个包含 severity(high/medium/low) 和 description
- suggestions: 改进建议数组
- retryable: true/false是否可以重试
## 注意
- 如果只有轻微问题passed 可以为 true
- 如果有严重问题passed 应为 false
- 判断是否需要重试,而不是立即失败
"""
# Research Worker System Prompt
RESEARCH_SYSTEM_PROMPT = """你是一个信息搜索和调研专家Researcher。你的职责是根据任务要求搜集和整理信息。
## 任务
{task}
## 共享上下文
{context}
## 请执行以下步骤
### 1. 理解任务
明确需要搜集什么信息,信息的用途是什么。
### 2. 搜索信息
使用可用工具搜索相关信息。
### 3. 整理结果
将搜索结果整理成结构化的信息。
## 输出要求
- 提供清晰、结构化的信息整理
- 标注信息来源
- 如果无法完成任务,说明原因
"""
# Coder Worker System Prompt
CODER_SYSTEM_PROMPT = """你是一个代码编写专家Coder。你的职责是根据任务要求编写和修改代码。
## 任务
{task}
## 共享上下文
{context}
## 请执行以下步骤
### 1. 理解需求
明确需要编写什么代码,代码的用途和约束。
### 2. 编写代码
使用合适的编程语言和框架编写代码。
### 3. 代码检查
确保代码语法正确,逻辑合理。
## 输出要求
- 提供完整的、可运行的代码
- 包含必要的注释说明
- 如果需要执行代码,使用代码执行工具
"""
# Aggregator System Prompt
AGGREGATOR_SYSTEM_PROMPT = """你是一个结果汇总专家Aggregator。你的职责是将多个子任务的结果汇总成最终输出。
## 原始任务
{original_task}
## 任务计划
{task_plan}
## 执行结果
{results}
## 共享上下文
{context}
## 请执行以下步骤
### 1. 分析结果
分析每个子任务的执行结果。
### 2. 识别关键信息
从结果中提取关键信息。
### 3. 汇总输出
将所有结果整合成一个连贯的最终输出。
## 输出要求
- 提供清晰、完整的最终结果
- 标注每个部分的来源
- 确保结果解决了原始任务
"""

View File

@@ -1,262 +0,0 @@
"""
Supervisor Agent - 负责任务规划和分发
"""
import json
import re
from typing import Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.messages import HumanMessage, SystemMessage
from .types import AgentState, TaskItem, AgentType, SupervisorDecision
from .prompts import SUPERVISOR_SYSTEM_PROMPT
class SupervisorAgent:
"""Supervisor Agent - 负责任务规划和分发"""
def __init__(
self,
llm: BaseChatModel,
max_iterations: int = 3,
max_tasks: int = 10
):
self.llm = llm
self.max_iterations = max_iterations
self.max_tasks = max_tasks
def create_node(self):
"""创建 LangGraph 节点"""
return self._supervisor_node
async def _supervisor_node(self, state: AgentState) -> dict:
"""Supervisor 节点逻辑"""
# 首次调用:分析任务并生成计划
if not state.get("task_plan"):
decision = await self._plan_tasks(
task=state["original_task"],
progress="这是任务的开始",
context=state.get("shared_context", {})
)
return {
"task_plan": decision.task_plan,
"next_node": decision.next_worker,
"current_task_index": 0,
"shared_context": {
**state.get("shared_context", {}),
"task_analysis": decision.analysis
}
}
# 非首次调用:检查任务状态,决定下一步
current_task_index = state.get("current_task_index", 0)
task_plan = state.get("task_plan", [])
# 获取当前任务
if current_task_index >= len(task_plan):
# 所有任务完成,进入汇总
return {
"next_node": "aggregator",
"shared_context": state.get("shared_context", {})
}
current_task = task_plan[current_task_index]
# 检查当前任务状态
if current_task.status == "completed":
# 当前任务完成,检查是否还有更多任务
if current_task_index + 1 < len(task_plan):
next_index = current_task_index + 1
next_task = task_plan[next_index]
return {
"current_task_index": next_index,
"next_node": next_task.assigned_agent,
"iteration": state.get("iteration", 0),
"shared_context": state.get("shared_context", {})
}
else:
# 所有任务完成,进入汇总
return {
"next_node": "aggregator",
"shared_context": state.get("shared_context", {})
}
elif current_task.status == "failed":
# 任务失败,检查是否超过最大重试
if current_task.retry_count >= self.max_iterations:
# 超过最大重试,进入汇总(标记失败)
return {
"next_node": "aggregator",
"status": "failed",
"shared_context": state.get("shared_context", {})
}
else:
# 重试当前任务
return {
"next_node": current_task.assigned_agent,
"iteration": state.get("iteration", 0) + 1,
"shared_context": state.get("shared_context", {})
}
elif current_task.status == "needs_retry":
# 需要重试(来自 review
return {
"next_node": current_task.assigned_agent,
"iteration": state.get("iteration", 0) + 1,
"shared_context": state.get("shared_context", {})
}
# 默认继续执行
return {
"next_node": state.get("next_node", "aggregator"),
"shared_context": state.get("shared_context", {})
}
async def _plan_tasks(self, task: str, progress: str, context: dict) -> SupervisorDecision:
"""调用 LLM 生成任务计划"""
# 格式化 prompt
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else ""
prompt = SUPERVISOR_SYSTEM_PROMPT.format(
task=task,
progress=progress,
context=context_str
)
# 调用 LLM
response = await self.llm.ainvoke([
SystemMessage(content=prompt),
HumanMessage(content="请分析任务并制定执行计划。")
])
# 解析 LLM 输出
decision = self._parse_response(response.content, task)
return decision
def _parse_response(self, response: str, original_task: str) -> SupervisorDecision:
"""解析 LLM 响应为结构化决策"""
try:
# 尝试提取 JSON
json_match = re.search(r'\{[\s\S]*\}', response)
if json_match:
data = json.loads(json_match.group())
else:
raise ValueError("No JSON found")
# 解析任务计划
task_plan = []
for i, item in enumerate(data.get("task_plan", [])):
task = TaskItem(
id=item.get("id", f"task_{i+1}"),
description=item.get("description", ""),
assigned_agent=AgentType(item.get("assigned_agent", "coder")),
status="pending"
)
task_plan.append(task)
# 确定下一个 Worker
next_worker = data.get("next_worker", "research")
if isinstance(next_worker, dict):
next_worker = next_worker.get("assigned_agent", "research")
return SupervisorDecision(
analysis=data.get("analysis", "任务分析"),
task_plan=task_plan,
need_aggregation=data.get("need_aggregation", True),
next_worker=AgentType(next_worker)
)
except Exception as e:
# 解析失败,创建默认计划
return self._create_default_plan(original_task)
def _create_default_plan(self, task: str) -> SupervisorDecision:
"""创建默认任务计划"""
task_lower = task.lower()
# 根据任务关键词判断
if any(keyword in task_lower for keyword in ["搜索", "查找", "调研", "研究", "research", "search"]):
assigned_agent = AgentType.RESEARCH
elif any(keyword in task_lower for keyword in ["代码", "", "开发", "code", "program", "写代码"]):
assigned_agent = AgentType.CODER
else:
assigned_agent = AgentType.CODER
# 创建默认任务
task_item = TaskItem(
id="task_1",
description=task,
assigned_agent=assigned_agent,
status="pending"
)
return SupervisorDecision(
analysis="简单任务,直接分配给合适的 Agent 执行",
task_plan=[task_item],
need_aggregation=True,
next_worker=assigned_agent
)
class ResultAggregator:
"""结果聚合器 - 汇总多个任务的结果"""
def __init__(self, llm: BaseChatModel):
self.llm = llm
def create_node(self):
"""创建 LangGraph 节点"""
return self._aggregate_node
async def _aggregate_node(self, state: AgentState) -> dict:
"""聚合节点逻辑"""
# 准备任务计划和结果
task_plan = state.get("task_plan", [])
results = state.get("results", {})
original_task = state.get("original_task", "")
# 构建任务描述
task_descriptions = []
for task in task_plan:
task_descriptions.append(f"- {task.id}: {task.description} -> {task.status}")
# 构建结果描述
result_items = []
for task_id, result in results.items():
if isinstance(result, dict):
content = result.get("content", str(result))
else:
content = str(result)
result_items.append(f"## {task_id}\n{content}")
# 调用 LLM 汇总结果
from .prompts import AGGREGATOR_SYSTEM_PROMPT
context_str = json.dumps(state.get("shared_context", {}), ensure_ascii=False, indent=2)
prompt = AGGREGATOR_SYSTEM_PROMPT.format(
original_task=original_task,
task_plan="\n".join(task_descriptions),
results="\n\n".join(result_items) if result_items else "无结果",
context=context_str
)
response = await self.llm.ainvoke([
SystemMessage(content=prompt),
HumanMessage(content="请汇总以上结果,给出最终输出。")
])
# 检查是否有失败的任务
has_failed = any(task.status == "failed" for task in task_plan)
return {
"final_output": response.content,
"status": "failed" if has_failed else "completed",
"next_node": "__end__"
}

View File

@@ -1,93 +0,0 @@
"""
多智能体系统数据类型定义
"""
from typing import TypedDict, Annotated, Optional, Literal
from operator import add
from pydantic import BaseModel, Field
from enum import Enum
class TaskStatus(str, Enum):
"""任务状态"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
NEEDS_RETRY = "needs_retry"
class AgentType(str, Enum):
"""Agent 类型"""
SUPERVISOR = "supervisor"
RESEARCH = "research"
CODER = "coder"
REVIEW = "review"
AGGREGATOR = "aggregator"
class TaskItem(BaseModel):
"""单个任务项"""
id: str = Field(..., description="任务唯一标识")
description: str = Field(..., description="任务描述")
assigned_agent: AgentType = Field(..., description="分配的 Agent 类型")
status: TaskStatus = Field(default=TaskStatus.PENDING, description="任务状态")
result: Optional[dict] = Field(default=None, description="任务执行结果")
error: Optional[str] = Field(default=None, description="错误信息")
retry_count: int = Field(default=0, description="重试次数")
class SupervisorDecision(BaseModel):
"""Supervisor 的结构化决策"""
analysis: str = Field(..., description="任务分析")
task_plan: list[TaskItem] = Field(..., description="任务计划")
need_aggregation: bool = Field(default=True, description="是否需要汇总")
next_worker: AgentType = Field(..., description="下一个执行的 Worker")
class ReviewResult(BaseModel):
"""Review 结果"""
passed: bool = Field(..., description="是否通过")
issues: list[dict] = Field(default_factory=list, description="问题列表")
suggestions: list[str] = Field(default_factory=list, description="改进建议")
retryable: bool = Field(default=True, description="是否可重试")
class AgentState(TypedDict):
"""贯穿整个图的 Agent 状态"""
# 用户输入
original_task: str # 原始任务描述
session_id: Optional[str] # 会话 ID
# 任务规划
task_plan: list[TaskItem] # 分解后的任务列表
current_task_index: int # 当前执行的任务索引
# 执行结果
results: dict # {task_id: result}
# 流程控制
iteration: int # 当前迭代次数
next_node: str # 下一个节点名称
# 共享上下文
shared_context: dict # Agent 间共享的数据
# 最终输出
final_output: str
status: Literal["running", "completed", "failed"] # 运行状态
def create_initial_state(task: str, session_id: str = None) -> AgentState:
"""创建初始状态"""
return {
"original_task": task,
"session_id": session_id,
"task_plan": [],
"current_task_index": 0,
"results": {},
"iteration": 0,
"next_node": "supervisor",
"shared_context": {},
"final_output": "",
"status": "running"
}

View File

@@ -1,14 +0,0 @@
"""
Worker Agents
"""
from .base import BaseWorker
from .research import ResearchWorker
from .coder import CoderWorker
from .review import ReviewWorker
__all__ = [
"BaseWorker",
"ResearchWorker",
"CoderWorker",
"ReviewWorker",
]

View File

@@ -1,138 +0,0 @@
"""
Worker Agent 基类
"""
import json
from abc import ABC, abstractmethod
from typing import Any, Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from ..types import AgentState, TaskItem, TaskStatus
class BaseWorker(ABC):
"""Worker Agent 基类"""
def __init__(
self,
llm: BaseChatModel,
name: str,
system_prompt: str,
tools: list = None,
tool_registry=None
):
self.llm = llm
self.name = name
self.system_prompt = system_prompt
self.tools = tools or []
self.tool_registry = tool_registry
@abstractmethod
async def execute(self, task: TaskItem, context: dict) -> dict:
"""
执行任务
Returns:
dict: {
"success": bool,
"content": str,
"context": dict, # 更新共享上下文
"error": str (optional)
}
"""
pass
def create_node(self):
"""创建 LangGraph 节点"""
async def node(state: AgentState) -> dict:
task_index = state.get("current_task_index", 0)
task_plan = state.get("task_plan", [])
if task_index >= len(task_plan):
return {"next_node": "aggregator"}
task = task_plan[task_index]
shared_context = state.get("shared_context", {})
# 更新任务状态为 running
updated_plan = self._update_task_status(task_plan, task.id, TaskStatus.RUNNING)
try:
# 执行任务
result = await self.execute(task, shared_context)
# 更新任务状态
if result.get("success"):
updated_plan = self._update_task_status(
updated_plan,
task.id,
TaskStatus.COMPLETED,
result=result.get("content", "")
)
else:
updated_plan = self._update_task_status(
updated_plan,
task.id,
TaskStatus.FAILED,
error=result.get("error", "Unknown error")
)
# 构建新上下文
new_context = {**shared_context, **(result.get("context", {}))}
return {
"task_plan": updated_plan,
"results": {**state.get("results", {}), task.id: result},
"shared_context": new_context,
"next_node": "review"
}
except Exception as e:
# 执行出错
updated_plan = self._update_task_status(
updated_plan,
task.id,
TaskStatus.FAILED,
error=str(e)
)
return {
"task_plan": updated_plan,
"results": {**state.get("results", {}), task.id: {"error": str(e)}},
"next_node": "review"
}
return node
def _update_task_status(
self,
tasks: list,
task_id: str,
status: TaskStatus,
result: Any = None,
error: str = None
) -> list:
"""更新任务状态"""
return [
{
**task.model_dump() if hasattr(task, 'model_dump') else task,
"status": status,
"result": result,
"error": error
}
for task in tasks
]
def _build_messages(self, task: str, context: dict) -> list:
"""构建消息列表"""
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else ""
user_prompt = self.system_prompt.format(
task=task,
context=context_str
)
return [
SystemMessage(content=user_prompt),
HumanMessage(content=task)
]

View File

@@ -1,146 +0,0 @@
"""
Coder Worker - 代码编写和修改
"""
import json
from langchain_core.language_models import BaseChatModel
from .base import BaseWorker
from ..types import TaskItem
from ..prompts import CODER_SYSTEM_PROMPT
class CoderWorker(BaseWorker):
"""Coder Worker - 代码编写和修改"""
def __init__(
self,
llm: BaseChatModel,
tool_registry=None,
tools: list = None
):
super().__init__(
llm=llm,
name="coder",
system_prompt=CODER_SYSTEM_PROMPT,
tools=tools or [],
tool_registry=tool_registry
)
async def execute(self, task: TaskItem, context: dict) -> dict:
"""执行编码任务"""
# 构建消息
messages = self._build_messages(task.description, context)
# 如果有代码执行工具,启用它
if self.tool_registry:
tool_defs = self._get_available_tools()
if tool_defs:
try:
response = await self.llm.agenerate(
messages=messages,
tools=tool_defs
)
return self._handle_tool_response(response, messages)
except Exception:
# 如果工具调用失败,回退到普通调用
pass
# 普通调用
try:
response = await self.llm.ainvoke(messages)
content = response.content if hasattr(response, 'content') else str(response)
return {
"success": True,
"content": content,
"context": {
"code_written": True,
"last_coder": self.name
}
}
except Exception as e:
return {
"success": False,
"content": "",
"error": str(e),
"context": {}
}
def _get_available_tools(self) -> list:
"""获取可用工具定义"""
if not self.tool_registry:
return []
tool_names = self.tools or ["search", "execute_code"]
tool_defs = []
for tool_name in tool_names:
tool_def = self.tool_registry.get_tool_definition(tool_name)
if tool_def:
tool_defs.append(tool_def)
return tool_defs
def _handle_tool_response(self, response, original_messages: list) -> dict:
"""处理工具调用响应"""
# 简化实现
response_message = response.generations[0][0]
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
# 有工具调用
tool_results = []
for tool_call in response_message.tool_calls:
tool_name = tool_call.name
tool_args = tool_call.arguments
# 执行工具
try:
tool_func, _ = self.tool_registry.get_tool(tool_name)
result = tool_func(**tool_args)
tool_results.append({
"tool": tool_name,
"result": str(result)
})
except Exception as e:
tool_results.append({
"tool": tool_name,
"error": str(e)
})
# 将工具结果添加到消息
for msg in response.generations[0]:
original_messages.append(msg)
for tool_result in tool_results:
original_messages.append({
"role": "tool",
"content": json.dumps(tool_result, ensure_ascii=False)
})
# 再次调用 LLM 生成最终响应
final_response = await self.llm.ainvoke(original_messages)
content = final_response.content if hasattr(final_response, 'content') else str(final_response)
return {
"success": True,
"content": content,
"context": {
"code_written": True,
"tool_results": tool_results,
"last_coder": self.name
}
}
else:
# 无工具调用
content = response_message.text if hasattr(response_message, 'text') else str(response_message)
return {
"success": True,
"content": content,
"context": {
"code_written": True,
"last_coder": self.name
}
}

View File

@@ -1,70 +0,0 @@
"""
Research Worker - 信息搜索和调研
"""
import json
from langchain_core.language_models import BaseChatModel
from .base import BaseWorker
from ..types import TaskItem
from ..prompts import RESEARCH_SYSTEM_PROMPT
class ResearchWorker(BaseWorker):
"""Research Worker - 信息搜索和调研"""
def __init__(
self,
llm: BaseChatModel,
tool_registry=None,
tools: list = None
):
super().__init__(
llm=llm,
name="research",
system_prompt=RESEARCH_SYSTEM_PROMPT,
tools=tools or [],
tool_registry=tool_registry
)
async def execute(self, task: TaskItem, context: dict) -> dict:
"""执行调研任务"""
# 构建消息
messages = self._build_messages(task.description, context)
try:
# 调用 LLM
response = await self.llm.ainvoke(messages)
content = response.content if hasattr(response, 'content') else str(response)
# 尝试提取搜索结果
search_results = self._extract_search_results(content)
return {
"success": True,
"content": content,
"context": {
"research_results": search_results,
"last_research_by": self.name
}
}
except Exception as e:
return {
"success": False,
"content": "",
"error": str(e),
"context": {}
}
def _extract_search_results(self, content: str) -> list:
"""从内容中提取搜索结果"""
# 简单实现:查找以 - 或 * 开头的行
results = []
for line in content.split('\n'):
line = line.strip()
if line.startswith(('- ', '* ', '1. ', '2. ', '3. ')):
results.append(line.lstrip('-*123. '))
return results[:10] # 限制数量

View File

@@ -1,174 +0,0 @@
"""
Review Worker - 结果检查和质量评审
"""
import json
import re
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from .base import BaseWorker
from ..types import AgentState, TaskItem, TaskStatus, ReviewResult
from ..prompts import REVIEW_SYSTEM_PROMPT
class ReviewWorker(BaseWorker):
"""Review Worker - 结果检查和质量评审"""
def __init__(
self,
llm: BaseChatModel,
tool_registry=None,
tools: list = None
):
super().__init__(
llm=llm,
name="review",
system_prompt=REVIEW_SYSTEM_PROMPT,
tools=tools or [],
tool_registry=tool_registry
)
async def execute(self, task: TaskItem, context: dict) -> dict:
"""执行评审任务"""
# 获取当前任务索引和任务计划
# 注意:这里需要从 context 中获取更多信息
# 构建 prompt
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else ""
prompt = REVIEW_SYSTEM_PROMPT.format(
original_task=context.get("original_task", ""),
task_description=task.description,
execution_result=task.result if task.result else "无结果",
context=context_str
)
try:
# 调用 LLM 进行评审
response = await self.llm.ainvoke([
SystemMessage(content=prompt),
HumanMessage(content="请评审以上执行结果。")
])
# 解析评审结果
review_result = self._parse_review_response(response.content)
# 根据评审结果决定下一步
if review_result.passed:
# 通过,更新任务状态为 completed
new_status = TaskStatus.COMPLETED
next_node = "supervisor" # 返回 Supervisor 继续执行
else:
# 未通过,检查是否可重试
if review_result.retryable:
new_status = TaskStatus.NEEDS_RETRY
next_node = "supervisor" # 返回 Supervisor 决定是否重试
else:
new_status = TaskStatus.FAILED
next_node = "aggregator" # 失败,进入汇总
return {
"success": review_result.passed,
"content": response.content,
"review_result": review_result.model_dump() if hasattr(review_result, 'model_dump') else dict(review_result),
"context": {
"review_passed": review_result.passed,
"issues": review_result.issues,
"last_review_by": self.name
}
}
except Exception as e:
return {
"success": False,
"content": "",
"error": str(e),
"context": {}
}
def _parse_review_response(self, response: str) -> ReviewResult:
"""解析评审响应"""
try:
# 尝试提取 JSON
json_match = re.search(r'\{[\s\S]*\}', response)
if json_match:
data = json.loads(json_match.group())
else:
raise ValueError("No JSON found")
return ReviewResult(
passed=data.get("passed", True),
issues=data.get("issues", []),
suggestions=data.get("suggestions", []),
retryable=data.get("retryable", True)
)
except Exception:
# 解析失败,默认通过
return ReviewResult(
passed=True,
issues=[],
suggestions=[],
retryable=True
)
def create_node(self):
"""创建 LangGraph 节点"""
async def node(state: AgentState) -> dict:
task_index = state.get("current_task_index", 0)
task_plan = state.get("task_plan", [])
if task_index >= len(task_plan):
return {"next_node": "aggregator"}
task = task_plan[task_index]
shared_context = {
**state.get("shared_context", {}),
"original_task": state.get("original_task", "")
}
try:
# 执行评审
result = await self.execute(task, shared_context)
# 更新任务状态
review_passed = result.get("review_result", {}).get("passed", True)
retryable = result.get("review_result", {}).get("retryable", True)
if review_passed:
updated_status = TaskStatus.COMPLETED
elif retryable:
updated_status = TaskStatus.NEEDS_RETRY
else:
updated_status = TaskStatus.FAILED
updated_plan = self._update_task_status(
task_plan,
task.id,
updated_status,
result=task.result
)
# 确定下一步
if updated_status == TaskStatus.COMPLETED:
next_node = "supervisor"
elif updated_status == TaskStatus.NEEDS_RETRY:
next_node = "supervisor"
else:
next_node = "aggregator"
return {
"task_plan": updated_plan,
"results": {**state.get("results", {}), f"{task.id}_review": result},
"shared_context": {**shared_context, **result.get("context", {})},
"next_node": next_node
}
except Exception as e:
return {
"next_node": "aggregator",
"results": {**state.get("results", {}), f"{task.id}_review": {"error": str(e)}}
}
return node

View File

@@ -1,22 +0,0 @@
"""
工具实现模块
"""
# 基础工具
from . import search
from . import calculator
from . import time_tool
# 安全工具
from . import sandbox
from . import database
from . import api_client
__all__ = [
"search",
"calculator",
"time_tool",
"sandbox",
"database",
"api_client",
]

View File

@@ -1,166 +0,0 @@
"""
API 调用工具 - 安全的外部 API 调用
"""
import httpx
from typing import Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
class APIPermission(Enum):
"""API 权限级别"""
PUBLIC = "public" # 公开 API
APPROVED = "approved" # 已审批的 API
ADMIN = "admin" # 管理员 API
@dataclass
class APIEndpoint:
"""API 端点定义"""
name: str
url: str
method: str
permission: APIPermission
description: str
rate_limit: int = 60 # 每分钟请求次数
# API 白名单
ALLOWED_APIS = [
APIEndpoint(
name="weather",
url="https://api.weather.example.com/v1",
method="GET",
permission=APIPermission.PUBLIC,
description="获取天气信息",
rate_limit=30
),
APIEndpoint(
name="news",
url="https://newsapi.org/v2",
method="GET",
permission=APIPermission.PUBLIC,
description="获取新闻",
rate_limit=30
),
# 可以添加更多已审批的 API
]
class APICallTool:
"""
API 调用工具
安全特性:
- 只允许调用白名单中的 API
- 速率限制
- 请求超时
- 响应大小限制
"""
def __init__(self):
self.allowed_apis = {api.name: api for api in ALLOWED_APIS}
self.request_timeout = 10 # 请求超时(秒)
self.max_response_size = 1024 * 1024 # 最大响应大小1MB
async def call(
self,
api_name: str,
endpoint: str = "",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
"""
调用 API
Args:
api_name: API 名称(必须在白名单中)
endpoint: 具体的端点
params: 查询参数
headers: 请求头
Returns:
API 响应
"""
# 安全检查1: API 必须在白名单中
if api_name not in self.allowed_apis:
return {
"success": False,
"error": f"API '{api_name}' not in whitelist. Allowed: {list(self.allowed_apis.keys())}"
}
api = self.allowed_apis[api_name]
# 构建完整 URL
url = f"{api.url}/{endpoint}" if endpoint else api.url
try:
async with httpx.AsyncClient(timeout=self.request_timeout) as client:
# 根据方法调用
if api.method == "GET":
response = await client.get(url, params=params, headers=headers)
elif api.method == "POST":
response = await client.post(url, json=params, headers=headers)
else:
return {
"success": False,
"error": f"Method {api.method} not supported"
}
# 检查响应大小
if len(response.content) > self.max_response_size:
return {
"success": False,
"error": f"Response too large (max {self.max_response_size} bytes)"
}
return {
"success": True,
"status_code": response.status_code,
"data": response.json() if response.headers.get("content-type", "").startswith("application/json") else response.text,
"headers": dict(response.headers)
}
except httpx.TimeoutException:
return {
"success": False,
"error": "Request timeout"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
def list_apis(self) -> list:
"""列出所有可用的 API"""
return [
{
"name": api.name,
"description": api.description,
"method": api.method,
"permission": api.permission.value,
"rate_limit": api.rate_limit
}
for api in ALLOWED_APIS
]
# 全局实例
api_tool = APICallTool()
async def call_api(
api_name: str,
endpoint: str = "",
params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
API 调用工具(供 Agent 使用)
"""
return await api_tool.call(api_name, endpoint, params)
def list_allowed_apis() -> list:
"""列出允许的 API"""
return api_tool.list_apis()

View File

@@ -1,91 +0,0 @@
"""
计算器工具
"""
import ast
import operator
from typing import Any
# 安全运算符
SAFE_OPERATORS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Pow: operator.pow,
ast.Mod: operator.mod,
ast.USub: operator.neg,
}
def safe_eval_expr(node):
"""安全地求值表达式节点"""
if isinstance(node, ast.Num):
return node.n
elif isinstance(node, ast.BinOp):
left = safe_eval_expr(node.left)
right = safe_eval_expr(node.right)
op_type = type(node.op)
if op_type in SAFE_OPERATORS:
return SAFE_OPERATORS[op_type](left, right)
raise ValueError(f"Unsupported operator: {op_type}")
elif isinstance(node, ast.UnaryOp):
operand = safe_eval_expr(node.operand)
op_type = type(node.op)
if op_type in SAFE_OPERATORS:
return SAFE_OPERATORS[op_type](operand)
raise ValueError(f"Unsupported unary operator: {op_type}")
else:
raise ValueError(f"Unsupported expression: {ast.dump(node)}")
def calculate(expression: str) -> dict:
"""
执行数学计算
Args:
expression: 数学表达式,如 "2 + 2""sqrt(16)"
Returns:
计算结果
"""
try:
# 预处理:处理常见数学函数
expression = expression.replace("sqrt", "**0.5")
expression = expression.replace("pi", "3.14159265359")
expression = expression.replace("e", "2.71828182846")
# 解析表达式
tree = ast.parse(expression, mode='eval')
result = safe_eval_expr(tree.body)
return {
"success": True,
"expression": expression,
"result": result,
"type": type(result).__name__
}
except Exception as e:
return {
"success": False,
"expression": expression,
"error": str(e)
}
# 工具定义
TOOL_DEFINITION = {
"name": "calculator",
"description": "Perform mathematical calculations. Supports basic arithmetic (+, -, *, /), powers (**), and functions (sqrt).",
"parameters": {
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate, e.g., '2 + 2' or 'sqrt(16) + 5'"
}
},
"required": ["expression"]
}
}

View File

@@ -1,96 +0,0 @@
"""
数据库查询工具 - 安全的数据查询接口
"""
from typing import Dict, Any, List, Optional
import os
# 只读查询白名单 - 只允许 SELECT 语句
ALLOWED_TABLES = ["users", "agents", "sessions", "audit_logs"]
class DatabaseQueryTool:
"""
数据库查询工具
安全特性:
- 只允许 SELECT 查询
- 表名白名单
- 结果数量限制
"""
def __init__(self, connection_string: str = ""):
self.connection_string = connection_string or os.getenv(
"DATABASE_URL",
"postgresql://postgres:postgres@localhost:5432/x_agents"
)
self.max_rows = 100 # 最多返回100行
def query(self, sql: str, params: List[Any] = None) -> Dict[str, Any]:
"""
执行查询
Args:
sql: SQL 查询语句(必须是 SELECT
params: 查询参数
Returns:
查询结果
"""
# 安全检查1: 必须是 SELECT 语句
sql_upper = sql.strip().upper()
if not sql_upper.startswith("SELECT"):
return {
"success": False,
"error": "Only SELECT queries are allowed"
}
# 安全检查2: 禁止危险关键字
dangerous_keywords = [
"DROP", "DELETE", "INSERT", "UPDATE", "ALTER",
"CREATE", "TRUNCATE", "EXEC", "EXECUTE"
]
for keyword in dangerous_keywords:
if keyword in sql_upper:
return {
"success": False,
"error": f"Keyword '{keyword}' is not allowed"
}
# 安全检查3: 表名白名单
for table in ALLOWED_TABLES:
if f"FROM {table}" in sql_upper or f"JOIN {table}" in sql_upper:
# 表名在白名单中,允许
break
else:
# 没有找到白名单表
return {
"success": False,
"error": f"Table not in whitelist. Allowed: {ALLOWED_TABLES}"
}
# TODO: 实际执行查询(需要数据库连接)
# 这里返回模拟数据
return {
"success": True,
"message": "Query executed (mock mode - database not connected)",
"rows": [],
"columns": []
}
# 全局实例
db_tool = DatabaseQueryTool()
def query_data(sql: str) -> Dict[str, Any]:
"""
查询数据工具
Args:
sql: SELECT 查询语句
Returns:
查询结果
"""
return db_tool.query(sql)

View File

@@ -1,87 +0,0 @@
"""
网页搜索工具
"""
import httpx
from typing import Optional
async def search_web(query: str, max_results: int = 5) -> dict:
"""
搜索网页获取信息
Args:
query: 搜索关键词
max_results: 返回结果数量
Returns:
搜索结果
"""
# 这里可以使用搜索引擎API如 Google, Bing, DuckDuckGo 等
# 示例使用 DuckDuckGo API免费
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://api.duckduckgo.com/",
params={
"q": query,
"format": "json",
"no_html": 1,
"skip_disambig": 1
},
timeout=10.0
)
if response.status_code == 200:
data = response.json()
results = []
# 提取相关主题
if "RelatedTopics" in data:
for item in data["RelatedTopics"][:max_results]:
if "Text" in item:
results.append({
"title": item.get("Text", "").split(" - ")[0] if " - " in item.get("Text", "") else "",
"content": item.get("Text", ""),
"url": item.get("URL", "")
})
return {
"success": True,
"query": query,
"results": results,
"count": len(results)
}
else:
return {
"success": False,
"error": f"Search API returned status {response.status_code}"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
# 工具定义(用于 LLM
TOOL_DEFINITION = {
"name": "search",
"description": "Search the web for information. Use this when you need to find current information or facts.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"max_results": {
"type": "integer",
"description": "Maximum number of results to return",
"default": 5
}
},
"required": ["query"]
}
}

View File

@@ -1,70 +0,0 @@
"""
时间工具
"""
from datetime import datetime
from typing import Optional
def get_current_time(timezone: Optional[str] = None) -> dict:
"""
获取当前时间
Args:
timezone: 时区名称,如 "UTC", "Asia/Shanghai"
Returns:
当前时间信息
"""
now = datetime.now()
return {
"success": True,
"datetime": now.isoformat(),
"timestamp": now.timestamp(),
"date": now.strftime("%Y-%m-%d"),
"time": now.strftime("%H:%M:%S"),
"weekday": now.strftime("%A"),
"timezone": timezone or "Local Time"
}
def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> dict:
"""
格式化时间戳
Args:
timestamp: Unix 时间戳
format_str: 格式字符串
Returns:
格式化后的时间
"""
try:
dt = datetime.fromtimestamp(timestamp)
return {
"success": True,
"formatted": dt.strftime(format_str),
"datetime": dt.isoformat()
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
# 工具定义
TOOL_DEFINITION = {
"name": "get_current_time",
"description": "Get the current date and time. Useful for timestamps or scheduling.",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
"default": "Local"
}
}
}
}

View File

@@ -1,99 +0,0 @@
"""
工具注册表 - 管理所有可用工具(白名单机制)
"""
from typing import Any, Callable, Optional
from dataclasses import dataclass, asdict
from enum import Enum
class SecurityLevel(Enum):
"""工具安全等级"""
SAFE = "safe" # 安全操作
REVIEW = "review" # 需要审核
DANGER = "danger" # 危险操作
@dataclass
class ToolMetadata:
"""工具元数据"""
name: str
description: str
security_level: str
require_approval: bool = False
allowed_roles: list = None
def dict(self):
return {
"name": self.name,
"description": self.description,
"security_level": self.security_level,
"require_approval": self.require_approval
}
class ToolRegistry:
"""工具注册表"""
def __init__(self):
self._tools: dict[str, tuple[Callable, ToolMetadata]] = {}
self._definitions: dict[str, dict] = {}
def register(
self,
name: str,
func: Callable,
description: str = "",
security_level: str = "safe",
require_approval: bool = False,
allowed_roles: list = None,
parameters: dict = None
):
"""注册工具到白名单"""
metadata = ToolMetadata(
name=name,
description=description,
security_level=security_level,
require_approval=require_approval,
allowed_roles=allowed_roles or ["user", "admin"]
)
self._tools[name] = (func, metadata)
# 生成工具定义(用于 LLM 调用)
self._definitions[name] = {
"name": name,
"description": description,
"parameters": parameters or {
"type": "object",
"properties": {},
"required": []
}
}
def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
"""获取工具函数和元数据"""
if name not in self._tools:
raise ValueError(f"Tool '{name}' not found in whitelist")
return self._tools[name]
def get_tool_definition(self, name: str) -> Optional[dict]:
"""获取工具定义(用于 LLM"""
return self._definitions.get(name)
def list_tools(self) -> list[ToolMetadata]:
"""列出所有已注册工具"""
return [meta for _, meta in self._tools.values()]
def check_permission(self, tool_name: str, user_role: str) -> bool:
"""检查用户权限"""
if tool_name not in self._tools:
return False
_, metadata = self._tools[tool_name]
return user_role in metadata.allowed_roles
def need_approval(self, tool_name: str) -> bool:
"""判断是否需要审批"""
if tool_name not in self._tools:
return False
_, metadata = self._tools[tool_name]
return metadata.require_approval

View File

@@ -1,283 +0,0 @@
"""
沙盒执行环境 - 在项目内构建,不依赖 Docker
提供安全的代码执行环境
"""
import subprocess
import tempfile
import os
import shutil
import resource
import signal
import threading
from typing import Dict, Any, Optional
from pathlib import Path
class SandboxConfig:
"""沙盒配置"""
# 资源限制
MAX_MEMORY_MB = 256 # 最大内存 (MB)
MAX_CPU_PERCENT = 50 # 最大 CPU 百分比
MAX_EXECUTION_TIME = 30 # 最大执行时间 (秒)
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小 (bytes)
class Sandbox:
"""
沙盒执行器 - 使用 subprocess 隔离执行
安全特性:
- 内存限制
- CPU限制
- 超时控制
- 网络隔离(可选)
- 临时文件隔离
"""
def __init__(self, config: Optional[SandboxConfig] = None):
self.config = config or SandboxConfig()
self.temp_dir = None
def _setup_temp_dir(self) -> str:
"""创建临时目录"""
self.temp_dir = tempfile.mkdtemp(prefix="sandbox_")
return self.temp_dir
def _cleanup(self):
"""清理临时目录"""
if self.temp_dir and os.path.exists(self.temp_dir):
try:
shutil.rmtree(self.temp_dir)
except Exception as e:
print(f"Cleanup error: {e}")
def execute(
self,
code: str,
language: str = "python",
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""
在沙盒中执行代码
Args:
code: 要执行的代码
language: 语言类型 (python, javascript)
timeout: 超时时间(秒)
Returns:
执行结果
"""
timeout = timeout or self.config.MAX_EXECUTION_TIME
self._setup_temp_dir()
try:
if language == "python":
return self._execute_python(code, timeout)
elif language == "javascript":
return self._execute_javascript(code, timeout)
else:
return {
"success": False,
"error": f"Unsupported language: {language}"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
finally:
self._cleanup()
def _execute_python(self, code: str, timeout: int) -> Dict[str, Any]:
"""执行 Python 代码"""
# 创建临时文件
temp_file = os.path.join(self.temp_dir, "code.py")
with open(temp_file, "w", encoding="utf-8") as f:
f.write(code)
try:
# 构建命令
cmd = ["python", temp_file]
# 执行
result = subprocess.run(
cmd,
capture_output=True,
timeout=timeout,
cwd=self.temp_dir, # 限制工作目录
env=self._get_restricted_env(), # 限制环境变量
)
# 检查输出大小
stdout = result.stdout.decode("utf-8", errors="replace")
stderr = result.stderr.decode("utf-8", errors="replace")
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
return {
"success": result.returncode == 0,
"output": stdout,
"error": stderr if result.returncode != 0 else None,
"exit_code": result.returncode
}
except subprocess.TimeoutExpired:
return {
"success": False,
"error": f"Execution timeout ({timeout}s)",
"output": None
}
def _execute_javascript(self, code: str, timeout: int) -> Dict[str, Any]:
"""执行 JavaScript 代码"""
temp_file = os.path.join(self.temp_dir, "code.js")
with open(temp_file, "w", encoding="utf-8") as f:
f.write(code)
try:
# 尝试使用 node
cmd = ["node", temp_file]
result = subprocess.run(
cmd,
capture_output=True,
timeout=timeout,
cwd=self.temp_dir,
)
stdout = result.stdout.decode("utf-8", errors="replace")
stderr = result.stderr.decode("utf-8", errors="replace")
return {
"success": result.returncode == 0,
"output": stdout,
"error": stderr if result.returncode != 0 else None,
"exit_code": result.returncode
}
except subprocess.TimeoutExpired:
return {
"success": False,
"error": f"Execution timeout ({timeout}s)",
"output": None
}
except FileNotFoundError:
return {
"success": False,
"error": "Node.js not installed",
"output": None
}
def _get_restricted_env(self) -> Dict[str, str]:
"""
获取受限的环境变量
移除敏感变量,保留必要的 PATH
"""
# 保留 PATH移除其他敏感变量
safe_env = {
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
"LANG": "en_US.UTF-8",
"HOME": self.temp_dir,
"TMPDIR": self.temp_dir,
}
# 移除可能不安全的变量
unsafe_vars = [
"PYTHONPATH",
"PYTHONHOME",
"LD_PRELOAD",
"LD_LIBRARY_PATH",
]
for var in unsafe_vars:
if var in os.environ:
del os.environ[var]
return safe_env
class SafeEval:
"""
安全求值器 - 用于简单表达式计算
比沙盒更轻量,适用于不需要完全隔离的场景
"""
# 安全函数白名单
SAFE_BUILTINS = {
"abs": abs,
"min": min,
"max": max,
"sum": sum,
"len": len,
"round": round,
"pow": pow,
"print": print,
"str": str,
"int": int,
"float": float,
"bool": bool,
"list": list,
"dict": dict,
"tuple": tuple,
"set": set,
"range": range,
"enumerate": enumerate,
"zip": zip,
"map": map,
"filter": filter,
"sorted": sorted,
"reversed": reversed,
}
# 安全数学常量
SAFE_MATH = {
"pi": 3.14159265359,
"e": 2.71828182846,
}
@classmethod
def eval(cls, expression: str) -> Any:
"""
安全地求值表达式
Args:
expression: 数学表达式
Returns:
计算结果
"""
# 预处理表达式
expression = expression.replace("sqrt", "**0.5")
# 构建安全命名空间
safe_namespace = {
**cls.SAFE_BUILTINS,
**cls.SAFE_MATH,
"__builtins__": {} # 禁用__builtins__
}
try:
result = eval(expression, safe_namespace)
return result
except Exception as e:
raise ValueError(f"Evaluation error: {e}")
# 全局沙盒实例
sandbox = Sandbox()
# 装饰器:快速将函数封装为沙盒执行
def sandboxed(timeout: int = 30):
"""装饰器:为函数添加沙盒执行能力"""
def decorator(func):
def wrapper(code: str, *args, **kwargs):
result = sandbox.execute(code, timeout=timeout)
if not result["success"]:
raise RuntimeError(result.get("error", "Execution failed"))
return result["output"]
return wrapper
return decorator

View File

@@ -1,149 +0,0 @@
"""
API 路由定义
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from app.agent.core.agent import AgentManager
from app.security.approval import ApprovalService
router = APIRouter()
# 全局依赖(实际应该注入)
_agent_manager: Optional[AgentManager] = None
_approval_service: Optional[ApprovalService] = None
def get_agent_manager() -> AgentManager:
"""获取 Agent 管理器"""
# 这里应该从 app.state 获取
from app.main import agent_manager
if agent_manager is None:
raise HTTPException(status_code=503, detail="Agent service not initialized")
return agent_manager
def get_approval_service() -> ApprovalService:
"""获取审批服务"""
global _approval_service
if _approval_service is None:
_approval_service = ApprovalService()
return _approval_service
# ==================== 请求/响应模型 ====================
class ChatRequest(BaseModel):
"""聊天请求"""
agent_id: str
message: str
session_id: str = ""
context: dict = {}
class ChatResponse(BaseModel):
"""聊天响应"""
reply: str
session_id: str
tools_used: list[str] = []
metadata: dict = {}
class ApprovalRequest(BaseModel):
"""审批请求"""
request_id: str
tool_name: str
params: dict
reason: str
approved: bool
# ==================== API 端点 ====================
@router.post("/chat", response_model=ChatResponse)
async def chat(
request: ChatRequest,
agent_manager: AgentManager = Depends(get_agent_manager)
):
"""处理 Agent 聊天请求"""
try:
# 生成会话ID
if not request.session_id:
import uuid
request.session_id = str(uuid.uuid4())
# 执行 Agent
result = await agent_manager.execute(
agent_id=request.agent_id,
message=request.message,
session_id=request.session_id,
context=request.context
)
return ChatResponse(
reply=result.get("reply", ""),
session_id=request.session_id,
tools_used=result.get("tools_used", []),
metadata=result.get("metadata", {})
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Agent execution failed: {str(e)}")
@router.post("/tool/request")
async def request_tool_execution(
request: dict,
approval_service: ApprovalService = Depends(get_approval_service)
):
"""请求执行工具(需要审批)"""
tool_name = request.get("tool_name")
params = request.get("params", {})
user_id = request.get("user_id", "unknown")
agent_id = request.get("agent_id")
reason = request.get("reason", "")
# 创建审批请求
request_id = await approval_service.request_approval(
tool_name=tool_name,
params=params,
user_id=user_id,
agent_id=agent_id or "",
reason=reason
)
return {
"request_id": request_id,
"status": "pending"
}
@router.get("/tools")
async def list_tools(agent_manager: AgentManager = Depends(get_agent_manager)):
"""列出所有可用工具"""
tools = agent_manager.list_tools()
return {"tools": [tool.dict() for tool in tools]}
@router.get("/agents")
async def list_agents(agent_manager: AgentManager = Depends(get_agent_manager)):
"""列出所有已加载的 Agent"""
agents = agent_manager.list_agents()
return {"agents": agents}
@router.get("/agent/{agent_id}")
async def get_agent(
agent_id: str,
agent_manager: AgentManager = Depends(get_agent_manager)
):
"""获取特定 Agent 信息"""
agent_info = agent_manager.get_agent_info(agent_id)
if not agent_info:
raise HTTPException(status_code=404, detail="Agent not found")
return agent_info

View File

@@ -1,8 +0,0 @@
"""
Core 模块 - AI 核心能力
"""
from . import tools
__all__ = [
"tools",
]

View File

@@ -1,130 +0,0 @@
"""
Agent 工具模块
"""
from .registry import ToolRegistry, ToolMetadata, SecurityLevel, global_registry
from . import impl
# 导入所有工具函数和定义
from .impl import (
# 文件操作
read_file,
write_file,
list_dir,
delete_file,
search_files,
READ_FILE_TOOL,
WRITE_FILE_TOOL,
LIST_DIR_TOOL,
DELETE_FILE_TOOL,
SEARCH_FILES_TOOL,
# 代码执行
execute_python,
execute_javascript,
execute_bash,
EXECUTE_PYTHON_TOOL,
EXECUTE_JAVASCRIPT_TOOL,
EXECUTE_BASH_TOOL,
# 网页
web_fetch,
web_search,
WEB_FETCH_TOOL,
WEB_SEARCH_TOOL,
# HTTP
http_request,
http_get,
http_post,
http_put,
http_delete,
HTTP_REQUEST_TOOL,
# 通知
send_notification,
send_email,
send_webhook,
SEND_NOTIFICATION_TOOL,
# 时间
get_current_time,
format_time,
GET_CURRENT_TIME_TOOL,
)
def register_all_tools(registry: ToolRegistry = None):
"""
注册所有工具到注册表
Args:
registry: 工具注册表,默认使用全局注册表
"""
reg = registry or global_registry
# 文件操作
reg.register("read_file", read_file, READ_FILE_TOOL["description"], "safe", parameters=READ_FILE_TOOL["parameters"])
reg.register("write_file", write_file, WRITE_FILE_TOOL["description"], "review", parameters=WRITE_FILE_TOOL["parameters"])
reg.register("list_dir", list_dir, LIST_DIR_TOOL["description"], "safe", parameters=LIST_DIR_TOOL["parameters"])
reg.register("delete_file", delete_file, DELETE_FILE_TOOL["description"], "danger", parameters=DELETE_FILE_TOOL["parameters"])
reg.register("search_files", search_files, SEARCH_FILES_TOOL["description"], "safe", parameters=SEARCH_FILES_TOOL["parameters"])
# 代码执行
reg.register("execute_python", execute_python, EXECUTE_PYTHON_TOOL["description"], "review", parameters=EXECUTE_PYTHON_TOOL["parameters"])
reg.register("execute_javascript", execute_javascript, EXECUTE_JAVASCRIPT_TOOL["description"], "review", parameters=EXECUTE_JAVASCRIPT_TOOL["parameters"])
reg.register("execute_bash", execute_bash, EXECUTE_BASH_TOOL["description"], "danger", parameters=EXECUTE_BASH_TOOL["parameters"])
# 网页
reg.register("web_fetch", web_fetch, WEB_FETCH_TOOL["description"], "safe", parameters=WEB_FETCH_TOOL["parameters"])
reg.register("web_search", web_search, WEB_SEARCH_TOOL["description"], "safe", parameters=WEB_SEARCH_TOOL["parameters"])
# HTTP
reg.register("http_request", http_request, HTTP_REQUEST_TOOL["description"], "safe", parameters=HTTP_REQUEST_TOOL["parameters"])
# 通知
reg.register("send_notification", send_notification, SEND_NOTIFICATION_TOOL["description"], "safe", parameters=SEND_NOTIFICATION_TOOL["parameters"])
# 时间
reg.register("get_current_time", get_current_time, GET_CURRENT_TIME_TOOL["description"], "safe", parameters=GET_CURRENT_TIME_TOOL["parameters"])
return reg
# 注册所有工具
register_all_tools(global_registry)
__all__ = [
"ToolRegistry",
"ToolMetadata",
"SecurityLevel",
"global_registry",
"register_all_tools",
"impl",
# 所有工具函数
"read_file",
"write_file",
"list_dir",
"delete_file",
"search_files",
"execute_python",
"execute_javascript",
"execute_bash",
"web_fetch",
"web_search",
"http_request",
"http_get",
"http_post",
"http_put",
"http_delete",
"send_notification",
"send_email",
"send_webhook",
"get_current_time",
"format_time",
]

View File

@@ -1,100 +0,0 @@
"""
工具实现模块
"""
from .files import (
read_file,
write_file,
list_dir,
delete_file,
search_files,
READ_FILE_TOOL,
WRITE_FILE_TOOL,
LIST_DIR_TOOL,
DELETE_FILE_TOOL,
SEARCH_FILES_TOOL,
)
from .executor import (
execute_python,
execute_javascript,
execute_bash,
EXECUTE_PYTHON_TOOL,
EXECUTE_JAVASCRIPT_TOOL,
EXECUTE_BASH_TOOL,
)
from .web import (
web_fetch,
web_search,
WEB_FETCH_TOOL,
WEB_SEARCH_TOOL,
)
from .http import (
http_request,
http_get,
http_post,
http_put,
http_delete,
HTTP_REQUEST_TOOL,
)
from .notify import (
send_notification,
send_email,
send_webhook,
SEND_NOTIFICATION_TOOL,
)
from .time_tool import (
get_current_time,
format_time,
GET_CURRENT_TIME_TOOL,
)
__all__ = [
# 文件操作
"read_file",
"write_file",
"list_dir",
"delete_file",
"search_files",
"READ_FILE_TOOL",
"WRITE_FILE_TOOL",
"LIST_DIR_TOOL",
"DELETE_FILE_TOOL",
"SEARCH_FILES_TOOL",
# 代码执行
"execute_python",
"execute_javascript",
"execute_bash",
"EXECUTE_PYTHON_TOOL",
"EXECUTE_JAVASCRIPT_TOOL",
"EXECUTE_BASH_TOOL",
# 网页
"web_fetch",
"web_search",
"WEB_FETCH_TOOL",
"WEB_SEARCH_TOOL",
# HTTP
"http_request",
"http_get",
"http_post",
"http_put",
"http_delete",
"HTTP_REQUEST_TOOL",
# 通知
"send_notification",
"send_email",
"send_webhook",
"SEND_NOTIFICATION_TOOL",
# 时间
"get_current_time",
"format_time",
"GET_CURRENT_TIME_TOOL",
]

View File

@@ -1,334 +0,0 @@
"""
代码执行工具
提供安全的Python、JavaScript、Bash代码执行
"""
import subprocess
import tempfile
import os
import shutil
from typing import Dict, Any, Optional
from pathlib import Path
class ExecutorConfig:
"""执行器配置"""
MAX_EXECUTION_TIME = 30 # 最大执行时间(秒)
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小(1MB)
MAX_MEMORY_MB = 256 # 最大内存(MB)
ALLOWED_PYTHON_PACKAGES = [] # 允许的Python包(空=仅标准库)
class CodeExecutor:
"""
代码执行器 - 在沙盒环境中执行代码
安全特性:
- 临时目录隔离
- 超时控制
- 输出大小限制
- 环境变量限制
"""
def __init__(self, config: Optional[ExecutorConfig] = None):
self.config = config or ExecutorConfig()
self.temp_dir: Optional[str] = None
def _setup_temp_dir(self) -> str:
"""创建临时目录"""
self.temp_dir = tempfile.mkdtemp(prefix="executor_")
return self.temp_dir
def _cleanup(self):
"""清理临时目录"""
if self.temp_dir and os.path.exists(self.temp_dir):
try:
shutil.rmtree(self.temp_dir)
except Exception:
pass
def _get_safe_env(self) -> Dict[str, str]:
"""获取安全的环境变量"""
return {
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
"LANG": "en_US.UTF-8",
"HOME": self.temp_dir or "/tmp",
"TMPDIR": self.temp_dir or "/tmp",
}
def execute_python(
self,
code: str,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""
执行Python代码
Args:
code: Python代码
timeout: 超时时间(秒)
Returns:
执行结果
"""
timeout = timeout or self.config.MAX_EXECUTION_TIME
self._setup_temp_dir()
try:
# 写入临时文件
temp_file = os.path.join(self.temp_dir, "code.py")
with open(temp_file, "w", encoding="utf-8") as f:
f.write(code)
# 执行
result = subprocess.run(
["python", temp_file],
capture_output=True,
timeout=timeout,
cwd=self.temp_dir,
env=self._get_safe_env(),
)
return self._process_result(result)
except subprocess.TimeoutExpired:
return {
"success": False,
"error": f"Execution timeout ({timeout}s)",
"language": "python"
}
except FileNotFoundError:
return {
"success": False,
"error": "Python not installed",
"language": "python"
}
except Exception as e:
return {
"success": False,
"error": str(e),
"language": "python"
}
finally:
self._cleanup()
def execute_javascript(
self,
code: str,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""
执行JavaScript代码
Args:
code: JavaScript代码
timeout: 超时时间(秒)
Returns:
执行结果
"""
timeout = timeout or self.config.MAX_EXECUTION_TIME
self._setup_temp_dir()
try:
# 写入临时文件
temp_file = os.path.join(self.temp_dir, "code.js")
with open(temp_file, "w", encoding="utf-8") as f:
f.write(code)
# 执行
result = subprocess.run(
["node", temp_file],
capture_output=True,
timeout=timeout,
cwd=self.temp_dir,
)
return self._process_result(result)
except subprocess.TimeoutExpired:
return {
"success": False,
"error": f"Execution timeout ({timeout}s)",
"language": "javascript"
}
except FileNotFoundError:
return {
"success": False,
"error": "Node.js not installed",
"language": "javascript"
}
except Exception as e:
return {
"success": False,
"error": str(e),
"language": "javascript"
}
finally:
self._cleanup()
def execute_bash(
self,
command: str,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""
执行Bash命令
Args:
command: Bash命令
timeout: 超时时间(秒)
Returns:
执行结果
"""
timeout = timeout or self.config.MAX_EXECUTION_TIME
self._setup_temp_dir()
# 安全检查:禁止的危险命令
dangerous_patterns = [
"rm -rf /",
"mkfs",
"dd if=",
">:/dev/sd",
"chmod 777 /",
"chown -R",
]
for pattern in dangerous_patterns:
if pattern in command:
return {
"success": False,
"error": f"Dangerous command blocked: {pattern}",
"language": "bash"
}
try:
# 执行
result = subprocess.run(
["bash", "-c", command],
capture_output=True,
timeout=timeout,
cwd=self.temp_dir,
env=self._get_safe_env(),
)
return self._process_result(result)
except subprocess.TimeoutExpired:
return {
"success": False,
"error": f"Execution timeout ({timeout}s)",
"language": "bash"
}
except FileNotFoundError:
return {
"success": False,
"error": "Bash not installed",
"language": "bash"
}
except Exception as e:
return {
"success": False,
"error": str(e),
"language": "bash"
}
finally:
self._cleanup()
def _process_result(self, result: subprocess.CompletedProcess) -> Dict[str, Any]:
"""处理执行结果"""
stdout = result.stdout.decode("utf-8", errors="replace")
stderr = result.stderr.decode("utf-8", errors="replace")
# 截断输出
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
return {
"success": result.returncode == 0,
"output": stdout,
"error": stderr if result.returncode != 0 else None,
"exit_code": result.returncode
}
# 全局执行器实例
executor = CodeExecutor()
# 便捷函数
def execute_python(code: str, timeout: int = 30) -> Dict[str, Any]:
"""执行Python代码"""
return executor.execute_python(code, timeout)
def execute_javascript(code: str, timeout: int = 30) -> Dict[str, Any]:
"""执行JavaScript代码"""
return executor.execute_javascript(code, timeout)
def execute_bash(command: str, timeout: int = 30) -> Dict[str, Any]:
"""执行Bash命令"""
return executor.execute_bash(command, timeout)
# 工具定义
EXECUTE_PYTHON_TOOL = {
"name": "execute_python",
"description": "Execute Python code in a sandboxed environment. Use this for Python programming tasks, calculations, and data processing.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The Python code to execute"
},
"timeout": {
"type": "integer",
"description": "Execution timeout in seconds (default: 30, max: 60)",
"default": 30
}
},
"required": ["code"]
}
}
EXECUTE_JAVASCRIPT_TOOL = {
"name": "execute_javascript",
"description": "Execute JavaScript code in a sandboxed environment. Use this for JavaScript programming tasks.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The JavaScript code to execute"
},
"timeout": {
"type": "integer",
"description": "Execution timeout in seconds (default: 30)",
"default": 30
}
},
"required": ["code"]
}
}
EXECUTE_BASH_TOOL = {
"name": "execute_bash",
"description": "Execute a bash command in a sandboxed environment. Use this for shell operations, file management, and system commands.",
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The bash command to execute"
},
"timeout": {
"type": "integer",
"description": "Execution timeout in seconds (default: 30)",
"default": 30
}
},
"required": ["command"]
}
}

View File

@@ -1,444 +0,0 @@
"""
文件操作工具
提供安全的文件读写、目录操作、搜索功能
"""
import os
import shutil
import glob as glob_module
from typing import Optional, List, Dict, Any
from pathlib import Path
class FileToolConfig:
"""文件工具配置"""
# 允许访问的基础目录(限制在项目内)
ALLOWED_BASE_DIRS = [
"account", # 用户工作区
"temp", # 临时文件
]
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
MAX_SEARCH_RESULTS = 100
def _resolve_safe_path(base_path: str, relative_path: str) -> str:
"""
解析安全的文件路径
确保路径不会超出基础目录
"""
# 规范化路径
full_path = os.path.normpath(os.path.join(base_path, relative_path))
# 检查是否在允许的基础目录内
path_parts = Path(full_path).parts
if len(path_parts) < 2:
raise ValueError("Invalid path: too short")
base_dir = path_parts[0]
if base_dir not in FileToolConfig.ALLOWED_BASE_DIRS and not base_dir.endswith(".py"):
# 允许 account 下的子目录
if len(path_parts) >= 2 and path_parts[0] != "account":
raise ValueError(f"Path not in allowed directories: {base_dir}")
return full_path
def read_file(file_path: str, encoding: str = "utf-8") -> Dict[str, Any]:
"""
读取文件内容
Args:
file_path: 文件路径
encoding: 文件编码
Returns:
文件内容
"""
try:
# 安全检查
full_path = _resolve_safe_path("", file_path)
if not os.path.exists(full_path):
return {
"success": False,
"error": f"File not found: {file_path}"
}
if not os.path.isfile(full_path):
return {
"success": False,
"error": f"Not a file: {file_path}"
}
# 检查文件大小
file_size = os.path.getsize(full_path)
if file_size > FileToolConfig.MAX_FILE_SIZE:
return {
"success": False,
"error": f"File too large: {file_size} bytes (max {FileToolConfig.MAX_FILE_SIZE})"
}
# 读取内容
with open(full_path, "r", encoding=encoding, errors="replace") as f:
content = f.read()
return {
"success": True,
"content": content,
"file_path": file_path,
"size": file_size,
"encoding": encoding
}
except ValueError as e:
return {
"success": False,
"error": str(e)
}
except Exception as e:
return {
"success": False,
"error": f"Read error: {str(e)}"
}
def write_file(file_path: str, content: str, encoding: str = "utf-8") -> Dict[str, Any]:
"""
写入文件内容
Args:
file_path: 文件路径
content: 文件内容
encoding: 文件编码
Returns:
写入结果
"""
try:
# 安全检查
full_path = _resolve_safe_path("", file_path)
# 检查内容大小
if len(content.encode(encoding)) > FileToolConfig.MAX_FILE_SIZE:
return {
"success": False,
"error": f"Content too large: {len(content)} bytes"
}
# 确保目录存在
os.makedirs(os.path.dirname(full_path), exist_ok=True)
# 写入内容
with open(full_path, "w", encoding=encoding) as f:
f.write(content)
return {
"success": True,
"file_path": file_path,
"bytes_written": len(content.encode(encoding))
}
except ValueError as e:
return {
"success": False,
"error": str(e)
}
except Exception as e:
return {
"success": False,
"error": f"Write error: {str(e)}"
}
def list_dir(dir_path: str = ".") -> Dict[str, Any]:
"""
列出目录内容
Args:
dir_path: 目录路径
Returns:
目录内容列表
"""
try:
full_path = _resolve_safe_path("", dir_path)
if not os.path.exists(full_path):
return {
"success": False,
"error": f"Directory not found: {dir_path}"
}
if not os.path.isdir(full_path):
return {
"success": False,
"error": f"Not a directory: {dir_path}"
}
items = []
for item in os.listdir(full_path):
item_path = os.path.join(full_path, item)
is_dir = os.path.isdir(item_path)
try:
size = 0 if is_dir else os.path.getsize(item_path)
except:
size = 0
items.append({
"name": item,
"type": "directory" if is_dir else "file",
"size": size
})
return {
"success": True,
"path": dir_path,
"items": items,
"count": len(items)
}
except ValueError as e:
return {
"success": False,
"error": str(e)
}
except Exception as e:
return {
"success": False,
"error": f"List error: {str(e)}"
}
def delete_file(file_path: str) -> Dict[str, Any]:
"""
删除文件或目录
Args:
file_path: 文件或目录路径
Returns:
删除结果
"""
try:
full_path = _resolve_safe_path("", file_path)
if not os.path.exists(full_path):
return {
"success": False,
"error": f"Path not found: {file_path}"
}
# 删除
if os.path.isfile(full_path):
os.remove(full_path)
elif os.path.isdir(full_path):
shutil.rmtree(full_path)
return {
"success": True,
"file_path": file_path,
"deleted": True
}
except ValueError as e:
return {
"success": False,
"error": str(e)
}
except Exception as e:
return {
"success": False,
"error": f"Delete error: {str(e)}"
}
def search_files(
directory: str,
pattern: str = "*",
content_pattern: Optional[str] = None,
file_only: bool = True
) -> Dict[str, Any]:
"""
搜索文件
Args:
directory: 搜索目录
pattern: 文件名匹配模式 (glob)
content_pattern: 文件内容匹配模式 (可选)
file_only: 是否只返回文件
Returns:
搜索结果
"""
try:
full_path = _resolve_safe_path("", directory)
if not os.path.exists(full_path) or not os.path.isdir(full_path):
return {
"success": False,
"error": f"Invalid directory: {directory}"
}
results = []
# 按文件名搜索
for match in glob_module.glob(os.path.join(full_path, "**", pattern), recursive=True):
if file_only and os.path.isdir(match):
continue
rel_path = os.path.relpath(match, full_path)
# 如果没有内容搜索,直接添加
if not content_pattern:
results.append({
"path": rel_path,
"name": os.path.basename(match),
"type": "directory" if os.path.isdir(match) else "file"
})
continue
# 内容搜索
if os.path.isfile(match):
try:
# 检查文件大小
if os.path.getsize(match) > FileToolConfig.MAX_FILE_SIZE:
continue
with open(match, "r", encoding="utf-8", errors="ignore") as f:
content = f.read()
if content_pattern.lower() in content.lower():
results.append({
"path": rel_path,
"name": os.path.basename(match),
"type": "file",
"match": content_pattern
})
except:
continue
# 限制结果数量
if len(results) > FileToolConfig.MAX_SEARCH_RESULTS:
results = results[:FileToolConfig.MAX_SEARCH_RESULTS]
return {
"success": True,
"directory": directory,
"pattern": pattern,
"results": results,
"count": len(results)
}
except ValueError as e:
return {
"success": False,
"error": str(e)
}
except Exception as e:
return {
"success": False,
"error": f"Search error: {str(e)}"
}
# 工具定义
READ_FILE_TOOL = {
"name": "read_file",
"description": "Read the contents of a file from the filesystem.",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The path to the file to read"
},
"encoding": {
"type": "string",
"description": "File encoding (default: utf-8)",
"default": "utf-8"
}
},
"required": ["file_path"]
}
}
WRITE_FILE_TOOL = {
"name": "write_file",
"description": "Write content to a file. Creates the file if it doesn't exist, overwrites if it does.",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The path to the file to write"
},
"content": {
"type": "string",
"description": "The content to write to the file"
},
"encoding": {
"type": "string",
"description": "File encoding (default: utf-8)",
"default": "utf-8"
}
},
"required": ["file_path", "content"]
}
}
LIST_DIR_TOOL = {
"name": "list_dir",
"description": "List the contents of a directory.",
"parameters": {
"type": "object",
"properties": {
"dir_path": {
"type": "string",
"description": "The path to the directory to list (default: current directory)",
"default": "."
}
}
}
}
DELETE_FILE_TOOL = {
"name": "delete_file",
"description": "Delete a file or directory.",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The path to the file or directory to delete"
}
},
"required": ["file_path"]
}
}
SEARCH_FILES_TOOL = {
"name": "search_files",
"description": "Search for files by name pattern or content.",
"parameters": {
"type": "object",
"properties": {
"directory": {
"type": "string",
"description": "The directory to search in"
},
"pattern": {
"type": "string",
"description": "Glob pattern for file names (e.g., '*.py', '*.txt')",
"default": "*"
},
"content_pattern": {
"type": "string",
"description": "Optional: search for files containing this text in their content"
},
"file_only": {
"type": "boolean",
"description": "Only return files, not directories",
"default": True
}
},
"required": ["directory"]
}
}

View File

@@ -1,271 +0,0 @@
"""
HTTP请求工具
提供通用的HTTP API调用功能
"""
import httpx
from typing import Dict, Any, Optional, List
class HTTPClientConfig:
"""HTTP客户端配置"""
DEFAULT_TIMEOUT = 30 # 默认超时(秒)
MAX_RESPONSE_SIZE = 5 * 1024 * 1024 # 最大响应大小(5MB)
MAX_REDIRECTS = 5 # 最大重定向次数
ALLOWED_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
class HTTPClient:
"""
HTTP客户端工具
安全特性:
- 只允许特定HTTP方法
- 响应大小限制
- 超时控制
- 请求/响应日志
"""
def __init__(self):
self.default_timeout = HTTPClientConfig.DEFAULT_TIMEOUT
async def request(
self,
url: str,
method: str = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
json_data: Optional[Dict[str, Any]] = None,
data: Optional[Any] = None,
timeout: Optional[int] = None,
allow_redirects: bool = True
) -> Dict[str, Any]:
"""
发送HTTP请求
Args:
url: 目标URL
method: HTTP方法
params: 查询参数
headers: 请求头
json_data: JSON请求体
data: 原始请求体
timeout: 超时时间
allow_redirects: 是否允许重定向
Returns:
响应结果
"""
# 安全检查:方法
method = method.upper()
if method not in HTTPClientConfig.ALLOWED_METHODS:
return {
"success": False,
"error": f"Method '{method}' not allowed. Allowed: {HTTPClientConfig.ALLOWED_METHODS}"
}
# 安全检查:协议
if not url.startswith(("http://", "https://")):
return {
"success": False,
"error": "Only HTTP and HTTPS protocols are allowed"
}
timeout = timeout or self.default_timeout
try:
async with httpx.AsyncClient(
timeout=timeout,
max_redirects=HTTPClientConfig.MAX_REDIRECTS if allow_redirects else 0,
follow_redirects=allow_redirects,
) as client:
response = await client.request(
method=method,
url=url,
params=params,
headers=headers,
json=json_data,
content=data,
)
# 检查响应大小
content_length = len(response.content)
if content_length > HTTPClientConfig.MAX_RESPONSE_SIZE:
return {
"success": False,
"error": f"Response too large: {content_length} bytes"
}
# 解析响应
content_type = response.headers.get("content-type", "")
if "application/json" in content_type:
try:
return {
"success": True,
"status_code": response.status_code,
"url": str(response.url),
"headers": dict(response.headers),
"json": response.json()
}
except:
pass
# 文本响应
return {
"success": True,
"status_code": response.status_code,
"url": str(response.url),
"headers": dict(response.headers),
"text": response.text[:HTTPClientConfig.MAX_RESPONSE_SIZE]
}
except httpx.TimeoutException:
return {
"success": False,
"error": f"Request timeout ({timeout}s)"
}
except httpx.InvalidURL:
return {
"success": False,
"error": "Invalid URL"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
async def get(
self,
url: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送GET请求"""
return await self.request(url, "GET", params, headers, timeout=timeout)
async def post(
self,
url: str,
json_data: Optional[Dict[str, Any]] = None,
data: Optional[Any] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送POST请求"""
return await self.request(url, "POST", None, headers, json_data, data, timeout)
async def put(
self,
url: str,
json_data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送PUT请求"""
return await self.request(url, "PUT", None, headers, json_data, None, timeout)
async def delete(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送DELETE请求"""
return await self.request(url, "DELETE", None, headers, timeout=timeout)
# 全局HTTP客户端
http_client = HTTPClient()
# 便捷函数
async def http_request(
url: str,
method: str = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
json_data: Optional[Dict[str, Any]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送HTTP请求"""
return await http_client.request(url, method, params, headers, json_data, None, timeout)
async def http_get(
url: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送GET请求"""
return await http_client.get(url, params, headers, timeout)
async def http_post(
url: str,
json_data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送POST请求"""
return await http_client.post(url, json_data, None, headers, timeout)
async def http_put(
url: str,
json_data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送PUT请求"""
return await http_client.put(url, json_data, headers, timeout)
async def http_delete(
url: str,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""发送DELETE请求"""
return await http_client.delete(url, headers, timeout)
# 工具定义
HTTP_REQUEST_TOOL = {
"name": "http_request",
"description": "Make HTTP requests to APIs. Supports GET, POST, PUT, DELETE methods with JSON data.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to request"
},
"method": {
"type": "string",
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
"default": "GET"
},
"params": {
"type": "object",
"description": "Query parameters for GET requests"
},
"headers": {
"type": "object",
"description": "Request headers"
},
"json_data": {
"type": "object",
"description": "JSON body for POST/PUT requests"
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30
}
},
"required": ["url"]
}
}

View File

@@ -1,379 +0,0 @@
"""
通知工具
提供发送通知的功能邮件、Webhook等
"""
import httpx
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from enum import Enum
class NotificationType(Enum):
"""通知类型"""
EMAIL = "email"
WEBHOOK = "webhook"
SMS = "sms"
DINGTALK = "dingtalk"
WECHAT = "wechat"
SLACK = "slack"
@dataclass
class NotificationConfig:
"""通知配置"""
# Email配置
smtp_host: str = ""
smtp_port: int = 587
smtp_user: str = ""
smtp_password: str = ""
from_email: str = ""
# Webhook配置
webhook_url: str = ""
webhook_secret: str = ""
# 钉钉配置
dingtalk_webhook: str = ""
# Slack配置
slack_webhook: str = ""
class NotificationTool:
"""
通知工具
支持多种通知渠道:
- Email (SMTP)
- Webhook
- 钉钉
- Slack
"""
def __init__(self, config: Optional[NotificationConfig] = None):
self.config = config or NotificationConfig()
async def send_email(
self,
to: str,
subject: str,
body: str,
cc: Optional[List[str]] = None,
is_html: bool = False
) -> Dict[str, Any]:
"""
发送邮件
Args:
to: 收件人
subject: 主题
body: 内容
cc: 抄送列表
is_html: 是否HTML格式
Returns:
发送结果
"""
if not self.config.smtp_host:
return {
"success": False,
"error": "Email not configured"
}
try:
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
# 构建邮件
msg = MIMEMultipart('alternative')
msg['Subject'] = subject
msg['From'] = self.config.from_email or self.config.smtp_user
msg['To'] = to
if cc:
msg['Cc'] = ",".join(cc)
# 添加内容
content_type = "html" if is_html else "plain"
msg.attach(MIMEText(body, content_type))
# 发送
with smtplib.SMTP(self.config.smtp_host, self.config.smtp_port) as server:
server.starttls()
server.login(self.config.smtp_user, self.config.smtp_password)
server.send_message(msg)
return {
"success": True,
"type": "email",
"to": to,
"subject": subject
}
except Exception as e:
return {
"success": False,
"error": str(e),
"type": "email"
}
async def send_webhook(
self,
url: str,
data: Dict[str, Any],
method: str = "POST",
headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
"""
发送Webhook
Args:
url: Webhook URL
data: 请求数据
method: HTTP方法
headers: 请求头
Returns:
发送结果
"""
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.request(
method=method,
url=url,
json=data,
headers=headers
)
return {
"success": response.status_code < 400,
"status_code": response.status_code,
"type": "webhook",
"url": url
}
except Exception as e:
return {
"success": False,
"error": str(e),
"type": "webhook"
}
async def send_dingtalk(
self,
message: str,
webhook: Optional[str] = None
) -> Dict[str, Any]:
"""
发送钉钉消息
Args:
message: 消息内容
webhook: 自定义webhook URL
Returns:
发送结果
"""
url = webhook or self.config.dingtalk_webhook
if not url:
return {
"success": False,
"error": "Dingtalk webhook not configured"
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(
url,
json={
"msgtype": "text",
"text": {
"content": message
}
}
)
result = response.json()
return {
"success": result.get("errcode") == 0,
"type": "dingtalk",
"response": result
}
except Exception as e:
return {
"success": False,
"error": str(e),
"type": "dingtalk"
}
async def send_slack(
self,
message: str,
channel: Optional[str] = None,
webhook: Optional[str] = None
) -> Dict[str, Any]:
"""
发送Slack消息
Args:
message: 消息内容
channel: 频道
webhook: 自定义webhook URL
Returns:
发送结果
"""
url = webhook or self.config.slack_webhook
if not url:
return {
"success": False,
"error": "Slack webhook not configured"
}
try:
payload = {"text": message}
if channel:
payload["channel"] = channel
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(url, json=payload)
return {
"success": response.status_code == 200,
"type": "slack",
"status_code": response.status_code
}
except Exception as e:
return {
"success": False,
"error": str(e),
"type": "slack"
}
async def send(
self,
type: str,
message: str,
**kwargs
) -> Dict[str, Any]:
"""
统一发送接口
Args:
type: 通知类型 (email, webhook, dingtalk, slack)
message: 消息内容
**kwargs: 其他参数
Returns:
发送结果
"""
type = type.lower()
if type == "email":
return await self.send_email(
to=kwargs.get("to", ""),
subject=kwargs.get("subject", "Notification"),
body=message,
cc=kwargs.get("cc")
)
elif type == "webhook":
return await self.send_webhook(
url=kwargs.get("url", ""),
data=kwargs.get("data", {"message": message})
)
elif type == "dingtalk":
return await self.send_dingtalk(
message=message,
webhook=kwargs.get("webhook")
)
elif type == "slack":
return await self.send_slack(
message=message,
channel=kwargs.get("channel"),
webhook=kwargs.get("webhook")
)
else:
return {
"success": False,
"error": f"Unknown notification type: {type}"
}
# 全局通知工具
notification_tool = NotificationTool()
# 便捷函数
async def send_notification(
type: str,
message: str,
**kwargs
) -> Dict[str, Any]:
"""发送通知"""
return await notification_tool.send(type, message, **kwargs)
async def send_email(
to: str,
subject: str,
body: str
) -> Dict[str, Any]:
"""发送邮件"""
return await notification_tool.send_email(to, subject, body)
async def send_webhook(
url: str,
data: Dict[str, Any]
) -> Dict[str, Any]:
"""发送Webhook"""
return await notification_tool.send_webhook(url, data)
# 工具定义
SEND_NOTIFICATION_TOOL = {
"name": "send_notification",
"description": "Send notifications via email, webhook, dingtalk, or slack.",
"parameters": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": "Notification type: email, webhook, dingtalk, slack",
"enum": ["email", "webhook", "dingtalk", "slack"]
},
"message": {
"type": "string",
"description": "The notification message"
},
"to": {
"type": "string",
"description": "For email: recipient email address"
},
"subject": {
"type": "string",
"description": "For email: email subject"
},
"url": {
"type": "string",
"description": "For webhook: webhook URL"
},
"data": {
"type": "object",
"description": "For webhook: JSON data to send"
},
"webhook": {
"type": "string",
"description": "Custom webhook URL for dingtalk/slack"
},
"channel": {
"type": "string",
"description": "For slack: channel name"
}
},
"required": ["type", "message"]
}
}

View File

@@ -1,70 +0,0 @@
"""
时间工具
"""
from datetime import datetime
from typing import Optional, Dict, Any
def get_current_time(timezone: Optional[str] = None) -> Dict[str, Any]:
"""
获取当前时间
Args:
timezone: 时区名称,如 "UTC", "Asia/Shanghai"
Returns:
当前时间信息
"""
now = datetime.now()
return {
"success": True,
"datetime": now.isoformat(),
"timestamp": now.timestamp(),
"date": now.strftime("%Y-%m-%d"),
"time": now.strftime("%H:%M:%S"),
"weekday": now.strftime("%A"),
"timezone": timezone or "Local Time"
}
def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> Dict[str, Any]:
"""
格式化时间戳
Args:
timestamp: Unix 时间戳
format_str: 格式字符串
Returns:
格式化后的时间
"""
try:
dt = datetime.fromtimestamp(timestamp)
return {
"success": True,
"formatted": dt.strftime(format_str),
"datetime": dt.isoformat()
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
# 工具定义
GET_CURRENT_TIME_TOOL = {
"name": "get_current_time",
"description": "Get the current date and time. Useful for timestamps or scheduling.",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
"default": "Local"
}
}
}
}

View File

@@ -1,233 +0,0 @@
"""
网页获取工具
提供安全的网页内容抓取功能
"""
import httpx
from typing import Dict, Any, Optional
class WebToolConfig:
"""网页工具配置"""
REQUEST_TIMEOUT = 30 # 请求超时(秒)
MAX_RESPONSE_SIZE = 2 * 1024 * 1024 # 最大响应大小(2MB)
MAX_REDIRECTS = 5 # 最大重定向次数
ALLOWED_PROTOCOLS = ["http", "https"] # 允许的协议
async def web_fetch(
url: str,
method: str = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
body: Optional[str] = None,
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""
获取网页内容
Args:
url: 目标URL
method: HTTP方法
params: 查询参数
headers: 请求头
body: 请求体
timeout: 超时时间
Returns:
网页内容
"""
timeout = timeout or WebToolConfig.REQUEST_TIMEOUT
# 安全检查:协议
if not url.startswith(("http://", "https://")):
return {
"success": False,
"error": "Only HTTP and HTTPS protocols are allowed"
}
try:
async with httpx.AsyncClient(
timeout=timeout,
max_redirects=WebToolConfig.MAX_REDIRECTS,
follow_redirects=True,
) as client:
# 发送请求
response = await client.request(
method=method,
url=url,
params=params,
headers=headers,
content=body,
)
# 检查响应大小
if len(response.content) > WebToolConfig.MAX_RESPONSE_SIZE:
return {
"success": False,
"error": f"Response too large: {len(response.content)} bytes (max {WebToolConfig.MAX_RESPONSE_SIZE})"
}
# 尝试解析JSON
content_type = response.headers.get("content-type", "")
if "application/json" in content_type:
try:
data = response.json()
return {
"success": True,
"url": str(response.url),
"status_code": response.status_code,
"content_type": content_type,
"data": data,
"headers": dict(response.headers)
}
except:
pass
# 返回文本
return {
"success": True,
"url": str(response.url),
"status_code": response.status_code,
"content_type": content_type,
"content": response.text[:WebToolConfig.MAX_RESPONSE_SIZE],
"headers": dict(response.headers)
}
except httpx.TimeoutException:
return {
"success": False,
"error": f"Request timeout ({timeout}s)"
}
except httpx.RedirectLoop:
return {
"success": False,
"error": "Too many redirects"
}
except httpx.InvalidURL:
return {
"success": False,
"error": "Invalid URL"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
async def web_search(
query: str,
max_results: int = 5
) -> Dict[str, Any]:
"""
搜索网页
Args:
query: 搜索关键词
max_results: 最大结果数
Returns:
搜索结果
"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
"https://api.duckduckgo.com/",
params={
"q": query,
"format": "json",
"no_html": 1,
"skip_disambig": 1
}
)
if response.status_code == 200:
data = response.json()
results = []
if "RelatedTopics" in data:
for item in data["RelatedTopics"][:max_results]:
if "Text" in item:
text = item.get("Text", "")
results.append({
"title": text.split(" - ")[0] if " - " in text else "",
"content": text,
"url": item.get("URL", "")
})
return {
"success": True,
"query": query,
"results": results,
"count": len(results)
}
else:
return {
"success": False,
"error": f"Search API returned status {response.status_code}"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
# 工具定义
WEB_FETCH_TOOL = {
"name": "web_fetch",
"description": "Fetch content from a web URL. Supports GET, POST methods and can return JSON or text content.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to fetch"
},
"method": {
"type": "string",
"description": "HTTP method (GET, POST)",
"default": "GET"
},
"params": {
"type": "object",
"description": "Query parameters"
},
"headers": {
"type": "object",
"description": "Request headers"
},
"body": {
"type": "string",
"description": "Request body (for POST)"
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30
}
},
"required": ["url"]
}
}
WEB_SEARCH_TOOL = {
"name": "web_search",
"description": "Search the web for information. Use this when you need to find current information or facts.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"max_results": {
"type": "integer",
"description": "Maximum number of results to return",
"default": 5
}
},
"required": ["query"]
}
}

View File

@@ -1,107 +0,0 @@
"""
工具注册表 - 管理所有可用工具(白名单机制)
"""
from typing import Any, Callable, Optional, Dict
from dataclasses import dataclass, asdict
from enum import Enum
class SecurityLevel(Enum):
"""工具安全等级"""
SAFE = "safe" # 安全操作
REVIEW = "review" # 需要审核
DANGER = "danger" # 危险操作
@dataclass
class ToolMetadata:
"""工具元数据"""
name: str
description: str
security_level: str
require_approval: bool = False
allowed_roles: list = None
def dict(self):
return {
"name": self.name,
"description": self.description,
"security_level": self.security_level,
"require_approval": self.require_approval
}
class ToolRegistry:
"""工具注册表"""
def __init__(self):
self._tools: Dict[str, tuple[Callable, ToolMetadata]] = {}
self._definitions: Dict[str, dict] = {}
def register(
self,
name: str,
func: Callable,
description: str = "",
security_level: str = "safe",
require_approval: bool = False,
allowed_roles: list = None,
parameters: dict = None
):
"""注册工具到白名单"""
metadata = ToolMetadata(
name=name,
description=description,
security_level=security_level,
require_approval=require_approval,
allowed_roles=allowed_roles or ["user", "admin"]
)
self._tools[name] = (func, metadata)
# 生成工具定义(用于 LLM 调用)
self._definitions[name] = {
"name": name,
"description": description,
"parameters": parameters or {
"type": "object",
"properties": {},
"required": []
}
}
def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
"""获取工具函数和元数据"""
if name not in self._tools:
raise ValueError(f"Tool '{name}' not found in whitelist")
return self._tools[name]
def get_tool_definition(self, name: str) -> Optional[dict]:
"""获取工具定义(用于 LLM"""
return self._definitions.get(name)
def list_tools(self) -> list[ToolMetadata]:
"""列出所有已注册工具"""
return [meta for _, meta in self._tools.values()]
def list_definitions(self) -> list[dict]:
"""列出所有工具定义用于LLM"""
return list(self._definitions.values())
def check_permission(self, tool_name: str, user_role: str) -> bool:
"""检查用户权限"""
if tool_name not in self._tools:
return False
_, metadata = self._tools[tool_name]
return user_role in metadata.allowed_roles
def need_approval(self, tool_name: str) -> bool:
"""判断是否需要审批"""
if tool_name not in self._tools:
return False
_, metadata = self._tools[tool_name]
return metadata.require_approval
# 全局工具注册表
global_registry = ToolRegistry()

View File

@@ -1,16 +0,0 @@
"""
沙盒模块
"""
from .sandbox import (
Sandbox,
SandboxConfig,
SafeEval,
sandbox,
)
__all__ = [
"Sandbox",
"SandboxConfig",
"SafeEval",
"sandbox",
]

View File

@@ -1,267 +0,0 @@
"""
沙盒执行环境 - 在项目内构建,不依赖 Docker
提供安全的代码执行环境
"""
import subprocess
import tempfile
import os
import shutil
from typing import Dict, Any, Optional
from pathlib import Path
class SandboxConfig:
"""沙盒配置"""
# 资源限制
MAX_MEMORY_MB = 256 # 最大内存 (MB)
MAX_CPU_PERCENT = 50 # 最大 CPU 百分比
MAX_EXECUTION_TIME = 30 # 最大执行时间 (秒)
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小 (bytes)
class Sandbox:
"""
沙盒执行器 - 使用 subprocess 隔离执行
安全特性:
- 内存限制
- CPU限制
- 超时控制
- 网络隔离(可选)
- 临时文件隔离
"""
def __init__(self, config: Optional[SandboxConfig] = None):
self.config = config or SandboxConfig()
self.temp_dir = None
def _setup_temp_dir(self) -> str:
"""创建临时目录"""
self.temp_dir = tempfile.mkdtemp(prefix="sandbox_")
return self.temp_dir
def _cleanup(self):
"""清理临时目录"""
if self.temp_dir and os.path.exists(self.temp_dir):
try:
shutil.rmtree(self.temp_dir)
except Exception as e:
print(f"Cleanup error: {e}")
def execute(
self,
code: str,
language: str = "python",
timeout: Optional[int] = None
) -> Dict[str, Any]:
"""
在沙盒中执行代码
Args:
code: 要执行的代码
language: 语言类型 (python, javascript)
timeout: 超时时间(秒)
Returns:
执行结果
"""
timeout = timeout or self.config.MAX_EXECUTION_TIME
self._setup_temp_dir()
try:
if language == "python":
return self._execute_python(code, timeout)
elif language == "javascript":
return self._execute_javascript(code, timeout)
else:
return {
"success": False,
"error": f"Unsupported language: {language}"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
finally:
self._cleanup()
def _execute_python(self, code: str, timeout: int) -> Dict[str, Any]:
"""执行 Python 代码"""
# 创建临时文件
temp_file = os.path.join(self.temp_dir, "code.py")
with open(temp_file, "w", encoding="utf-8") as f:
f.write(code)
try:
# 构建命令
cmd = ["python", temp_file]
# 执行
result = subprocess.run(
cmd,
capture_output=True,
timeout=timeout,
cwd=self.temp_dir, # 限制工作目录
env=self._get_restricted_env(), # 限制环境变量
)
# 检查输出大小
stdout = result.stdout.decode("utf-8", errors="replace")
stderr = result.stderr.decode("utf-8", errors="replace")
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
return {
"success": result.returncode == 0,
"output": stdout,
"error": stderr if result.returncode != 0 else None,
"exit_code": result.returncode
}
except subprocess.TimeoutExpired:
return {
"success": False,
"error": f"Execution timeout ({timeout}s)",
"output": None
}
def _execute_javascript(self, code: str, timeout: int) -> Dict[str, Any]:
"""执行 JavaScript 代码"""
temp_file = os.path.join(self.temp_dir, "code.js")
with open(temp_file, "w", encoding="utf-8") as f:
f.write(code)
try:
# 尝试使用 node
cmd = ["node", temp_file]
result = subprocess.run(
cmd,
capture_output=True,
timeout=timeout,
cwd=self.temp_dir,
)
stdout = result.stdout.decode("utf-8", errors="replace")
stderr = result.stderr.decode("utf-8", errors="replace")
return {
"success": result.returncode == 0,
"output": stdout,
"error": stderr if result.returncode != 0 else None,
"exit_code": result.returncode
}
except subprocess.TimeoutExpired:
return {
"success": False,
"error": f"Execution timeout ({timeout}s)",
"output": None
}
except FileNotFoundError:
return {
"success": False,
"error": "Node.js not installed",
"output": None
}
def _get_restricted_env(self) -> Dict[str, str]:
"""
获取受限的环境变量
移除敏感变量,保留必要的 PATH
"""
# 保留 PATH移除其他敏感变量
safe_env = {
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
"LANG": "en_US.UTF-8",
"HOME": self.temp_dir,
"TMPDIR": self.temp_dir,
}
# 移除可能不安全的变量
unsafe_vars = [
"PYTHONPATH",
"PYTHONHOME",
"LD_PRELOAD",
"LD_LIBRARY_PATH",
]
for var in unsafe_vars:
if var in os.environ:
del os.environ[var]
return safe_env
class SafeEval:
"""
安全求值器 - 用于简单表达式计算
比沙盒更轻量,适用于不需要完全隔离的场景
"""
# 安全函数白名单
SAFE_BUILTINS = {
"abs": abs,
"min": min,
"max": max,
"sum": sum,
"len": len,
"round": round,
"pow": pow,
"print": print,
"str": str,
"int": int,
"float": float,
"bool": bool,
"list": list,
"dict": dict,
"tuple": tuple,
"set": set,
"range": range,
"enumerate": enumerate,
"zip": zip,
"map": map,
"filter": filter,
"sorted": sorted,
"reversed": reversed,
}
# 安全数学常量
SAFE_MATH = {
"pi": 3.14159265359,
"e": 2.71828182846,
}
@classmethod
def eval(cls, expression: str) -> Any:
"""
安全地求值表达式
Args:
expression: 数学表达式
Returns:
计算结果
"""
# 预处理表达式
expression = expression.replace("sqrt", "**0.5")
# 构建安全命名空间
safe_namespace = {
**cls.SAFE_BUILTINS,
**cls.SAFE_MATH,
"__builtins__": {} # 禁用__builtins__
}
try:
result = eval(expression, safe_namespace)
return result
except Exception as e:
raise ValueError(f"Evaluation error: {e}")
# 全局沙盒实例
sandbox = Sandbox()

View File

@@ -1,347 +0,0 @@
{
"version": "1.0",
"tools": [
{
"name": "read_file",
"description": "Read the contents of a file from the filesystem.",
"category": "file",
"security_level": "safe",
"require_approval": false,
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The path to the file to read"
},
"encoding": {
"type": "string",
"description": "File encoding (default: utf-8)",
"default": "utf-8"
}
},
"required": ["file_path"]
}
},
{
"name": "write_file",
"description": "Write content to a file. Creates the file if it doesn't exist, overwrites if it does.",
"category": "file",
"security_level": "review",
"require_approval": true,
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The path to the file to write"
},
"content": {
"type": "string",
"description": "The content to write to the file"
},
"encoding": {
"type": "string",
"description": "File encoding (default: utf-8)",
"default": "utf-8"
}
},
"required": ["file_path", "content"]
}
},
{
"name": "list_dir",
"description": "List the contents of a directory.",
"category": "file",
"security_level": "safe",
"require_approval": false,
"parameters": {
"type": "object",
"properties": {
"dir_path": {
"type": "string",
"description": "The path to the directory to list (default: current directory)",
"default": "."
}
}
}
},
{
"name": "delete_file",
"description": "Delete a file or directory.",
"category": "file",
"security_level": "danger",
"require_approval": true,
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The path to the file or directory to delete"
}
},
"required": ["file_path"]
}
},
{
"name": "search_files",
"description": "Search for files by name pattern or content.",
"category": "file",
"security_level": "safe",
"require_approval": false,
"parameters": {
"type": "object",
"properties": {
"directory": {
"type": "string",
"description": "The directory to search in"
},
"pattern": {
"type": "string",
"description": "Glob pattern for file names (e.g., '*.py', '*.txt')",
"default": "*"
},
"content_pattern": {
"type": "string",
"description": "Optional: search for files containing this text in their content"
},
"file_only": {
"type": "boolean",
"description": "Only return files, not directories",
"default": true
}
},
"required": ["directory"]
}
},
{
"name": "execute_python",
"description": "Execute Python code in a sandboxed environment. Use this for Python programming tasks, calculations, and data processing.",
"category": "executor",
"security_level": "review",
"require_approval": true,
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The Python code to execute"
},
"timeout": {
"type": "integer",
"description": "Execution timeout in seconds (default: 30, max: 60)",
"default": 30
}
},
"required": ["code"]
}
},
{
"name": "execute_javascript",
"description": "Execute JavaScript code in a sandboxed environment. Use this for JavaScript programming tasks.",
"category": "executor",
"security_level": "review",
"require_approval": true,
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The JavaScript code to execute"
},
"timeout": {
"type": "integer",
"description": "Execution timeout in seconds (default: 30)",
"default": 30
}
},
"required": ["code"]
}
},
{
"name": "execute_bash",
"description": "Execute a bash command in a sandboxed environment. Use this for shell operations, file management, and system commands.",
"category": "executor",
"security_level": "danger",
"require_approval": true,
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The bash command to execute"
},
"timeout": {
"type": "integer",
"description": "Execution timeout in seconds (default: 30)",
"default": 30
}
},
"required": ["command"]
}
},
{
"name": "web_fetch",
"description": "Fetch content from a web URL. Supports GET, POST methods and can return JSON or text content.",
"category": "web",
"security_level": "safe",
"require_approval": false,
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to fetch"
},
"method": {
"type": "string",
"description": "HTTP method (GET, POST)",
"default": "GET"
},
"params": {
"type": "object",
"description": "Query parameters"
},
"headers": {
"type": "object",
"description": "Request headers"
},
"body": {
"type": "string",
"description": "Request body (for POST)"
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30
}
},
"required": ["url"]
}
},
{
"name": "web_search",
"description": "Search the web for information. Use this when you need to find current information or facts.",
"category": "web",
"security_level": "safe",
"require_approval": false,
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"max_results": {
"type": "integer",
"description": "Maximum number of results to return",
"default": 5
}
},
"required": ["query"]
}
},
{
"name": "http_request",
"description": "Make HTTP requests to APIs. Supports GET, POST, PUT, DELETE methods with JSON data.",
"category": "http",
"security_level": "safe",
"require_approval": false,
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to request"
},
"method": {
"type": "string",
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
"default": "GET"
},
"params": {
"type": "object",
"description": "Query parameters for GET requests"
},
"headers": {
"type": "object",
"description": "Request headers"
},
"json_data": {
"type": "object",
"description": "JSON body for POST/PUT requests"
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30
}
},
"required": ["url"]
}
},
{
"name": "send_notification",
"description": "Send notifications via email, webhook, dingtalk, or slack.",
"category": "notification",
"security_level": "safe",
"require_approval": false,
"parameters": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": "Notification type: email, webhook, dingtalk, slack",
"enum": ["email", "webhook", "dingtalk", "slack"]
},
"message": {
"type": "string",
"description": "The notification message"
},
"to": {
"type": "string",
"description": "For email: recipient email address"
},
"subject": {
"type": "string",
"description": "For email: email subject"
},
"url": {
"type": "string",
"description": "For webhook: webhook URL"
},
"data": {
"type": "object",
"description": "For webhook: JSON data to send"
},
"webhook": {
"type": "string",
"description": "Custom webhook URL for dingtalk/slack"
},
"channel": {
"type": "string",
"description": "For slack: channel name"
}
},
"required": ["type", "message"]
}
},
{
"name": "get_current_time",
"description": "Get the current date and time. Useful for timestamps or scheduling.",
"category": "system",
"security_level": "safe",
"require_approval": false,
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
"default": "Local"
}
}
}
}
]
}

View File

@@ -1,63 +0,0 @@
"""
LLM 工厂 - 创建不同提供商的 LLM 实例
"""
from typing import Optional
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
class LLMFactory:
"""LLM 工厂类"""
def __init__(
self,
provider: str = "openai",
openai_api_key: Optional[str] = None,
anthropic_api_key: Optional[str] = None,
model: str = "gpt-3.5-turbo",
temperature: float = 0.7,
max_tokens: int = 2000
):
self.provider = provider
self.openai_api_key = openai_api_key
self.anthropic_api_key = anthropic_api_key
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self._llm = None
def get_llm(self):
"""获取 LLM 实例"""
if self._llm is not None:
return self._llm
if self.provider == "openai":
self._llm = ChatOpenAI(
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
api_key=self.openai_api_key
)
elif self.provider == "anthropic":
self._llm = ChatAnthropic(
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
anthropic_api_key=self.anthropic_api_key
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
return self._llm
def set_model(self, model: str):
"""设置模型"""
self.model = model
self._llm = None # 重置 LLM 实例
def set_temperature(self, temperature: float):
"""设置温度"""
self.temperature = temperature
if self._llm:
self._llm.temperature = temperature

View File

@@ -1,58 +1,20 @@
"""
X-Agents Python Agent Service
智能体引擎服务入口
FastAPI Agent Engine Server
"""
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
import time
from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from app.api import routes
from app.agent.core.agent import AgentManager
from app.security.audit import AuditLogger
from app.agent.core import AgentCore, Supervisor, AgentConfig
from app.agent.llm import LLMFactory
# 全局组件
agent_manager: AgentManager = None
audit_logger: AuditLogger = None
app = FastAPI(title="X-Agents Python Engine", version="1.0.0")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
global agent_manager, audit_logger
# 启动时初始化
audit_logger = AuditLogger()
# 初始化 Agent 管理器
agent_manager = AgentManager(
llm_provider=os.getenv("LLM_PROVIDER", "openai"),
openai_api_key=os.getenv("OPENAI_API_KEY"),
anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
)
# 加载 Agent 配置
await agent_manager.load_agents()
print("Agent service started successfully")
yield
# 关闭时清理
print("Agent service shutting down")
# 创建 FastAPI 应用
app = FastAPI(
title="X-Agents Agent Service",
description="AI Agent Engine for X-Agents Platform",
version="1.0.0",
lifespan=lifespan,
)
# CORS 中间件
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -61,24 +23,180 @@ app.add_middleware(
allow_headers=["*"],
)
# 注册路由
app.include_router(routes.router, prefix="/agent", tags=["Agent"])
# === 请求/响应模型 ===
class ChatRequest(BaseModel):
"""对话请求"""
agent_id: int
message: str
user_id: int = 1
session_id: Optional[str] = None
@app.get("/health")
async def health_check():
"""健康检查"""
return {
"status": "healthy",
"service": "agent",
"version": "1.0.0"
class TeamChatRequest(BaseModel):
"""多智能体群聊请求"""
supervisor_agent_id: int
member_agent_ids: list[int]
message: str
user_id: int = 1
session_id: Optional[str] = None
strategy: str = "parallel"
class ChatResponse(BaseModel):
"""对话响应"""
agent_id: int
response: str
tool_calls: list = []
tokens_used: int = 0
duration_ms: int = 0
session_id: Optional[str] = None
# === 模拟数据存储 ===
# TODO: 后续替换为从数据库加载
_mock_agents = {
1: {
"id": 1,
"name": "数据分析助手",
"role_description": "你是一个专业的数据分析助手,擅长分析数据、生成报告。",
"model_provider": "openai",
"model_name": "gpt-4",
"skills": [1, 2]
},
2: {
"id": 2,
"name": "代码审查助手",
"role_description": "你是一个专业的代码审查助手擅长审查代码、发现bug。",
"model_provider": "openai",
"model_name": "gpt-4",
"skills": [3]
}
}
def get_agent_config(agent_id: int) -> AgentConfig:
"""获取智能体配置"""
agent_data = _mock_agents.get(agent_id)
if not agent_data:
raise HTTPException(status_code=404, detail="Agent not found")
return AgentConfig(
id=agent_data["id"],
name=agent_data["name"],
role_description=agent_data["role_description"],
model_provider=agent_data["model_provider"],
model_name=agent_data["model_name"],
skills=agent_data.get("skills", [])
)
# === API 路由 ===
@app.get("/")
async def root():
"""根路径"""
return {"message": "X-Agents Python Engine", "version": "1.0.0"}
@app.get("/health")
async def health():
return {"status": "healthy"}
@app.post("/agent/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
"""
单智能体对话
"""
start_time = time.time()
# 获取智能体配置
try:
config = get_agent_config(request.agent_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
# 创建智能体实例
agent = AgentCore(config)
# 生成 session_id
session_id = request.session_id or f"session_{int(time.time())}"
# 执行对话
try:
result = await agent.run(request.message, request.user_id, session_id)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
duration_ms = int((time.time() - start_time) * 1000)
return ChatResponse(
agent_id=request.agent_id,
response=result.content,
tool_calls=result.tool_calls,
tokens_used=result.tokens_used,
duration_ms=duration_ms,
session_id=session_id
)
@app.post("/agent/team/chat")
async def team_chat(request: TeamChatRequest):
"""
多智能体群聊
"""
start_time = time.time()
# 创建主智能体
try:
supervisor_config = get_agent_config(request.supervisor_agent_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
supervisor_agent = AgentCore(supervisor_config)
# 创建子智能体
members = []
for member_id in request.member_agent_ids:
try:
member_config = get_agent_config(member_id)
members.append(AgentCore(member_config))
except:
continue
if not members:
raise HTTPException(status_code=400, detail="No valid member agents")
# 创建调度器
supervisor = Supervisor(supervisor_agent, members, request.strategy)
# 生成 session_id
session_id = request.session_id or f"team_session_{int(time.time())}"
# 执行群聊
try:
result = await supervisor.run(request.message, request.user_id, session_id)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
duration_ms = int((time.time() - start_time) * 1000)
return {
"message": "X-Agents Agent Service",
"docs": "/docs"
"supervisor_agent_id": request.supervisor_agent_id,
"response": result["main_response"],
"subtask_results": result["subtask_results"],
"strategy": result["strategy"],
"duration_ms": duration_ms,
"session_id": session_id
}
if __name__ == "__main__":
import uvicorn
port = int(os.getenv("AGENT_PORT", "8081"))
uvicorn.run(app, host="0.0.0.0", port=port)

View File

@@ -1,104 +0,0 @@
"""
审批服务 - 处理工具执行的审批流程
"""
import uuid
from datetime import datetime
from typing import Any, Dict, Optional
from enum import Enum
class ApprovalStatus(Enum):
"""审批状态"""
PENDING = "pending"
APPROVED = "approved"
REJECTED = "rejected"
class ApprovalService:
"""审批服务"""
def __init__(self):
# 待审批队列
self.pending: Dict[str, dict] = {}
# 审批结果
self.results: Dict[str, ApprovalStatus] = {}
async def request_approval(
self,
tool_name: str,
params: dict,
user_id: str,
agent_id: str,
reason: str
) -> str:
"""
请求审批
Returns:
request_id: 审批请求ID
"""
request_id = str(uuid.uuid4())
request = {
"request_id": request_id,
"tool_name": tool_name,
"params": params,
"user_id": user_id,
"agent_id": agent_id,
"reason": reason,
"status": ApprovalStatus.PENDING,
"created_at": datetime.now().isoformat()
}
self.pending[request_id] = request
self.results[request_id] = ApprovalStatus.PENDING
# TODO: 通知 Go 后端有新审批
return request_id
async def check_approval(self, request_id: str, timeout: int = 300) -> bool:
"""
检查审批状态
Args:
request_id: 审批请求ID
timeout: 超时时间(秒)
Returns:
是否已批准
"""
import asyncio
start = datetime.now()
while (datetime.now() - start).seconds < timeout:
status = self.results.get(request_id)
if status == ApprovalStatus.APPROVED:
return True
elif status == ApprovalStatus.REJECTED:
return False
await asyncio.sleep(1)
raise TimeoutError("Approval request timeout")
async def approve(self, request_id: str):
"""批准请求"""
if request_id in self.pending:
self.pending[request_id]["status"] = ApprovalStatus.APPROVED
self.results[request_id] = ApprovalStatus.APPROVED
async def reject(self, request_id: str):
"""拒绝请求"""
if request_id in self.pending:
self.pending[request_id]["status"] = ApprovalStatus.REJECTED
self.results[request_id] = ApprovalStatus.REJECTED
def get_pending(self) -> list[dict]:
"""获取待审批列表"""
return [
req for req in self.pending.values()
if req["status"] == ApprovalStatus.PENDING
]

View File

@@ -1,81 +0,0 @@
"""
审计日志 - 记录所有 Agent 操作
"""
import json
from datetime import datetime
from typing import Any, Dict, Optional
from pathlib import Path
class AuditLogger:
"""审计日志记录器"""
def __init__(self, log_file: str = "audit.log"):
self.log_file = log_file
def log(
self,
action: str,
agent_id: str = "",
session_id: str = "",
user_id: str = "",
details: Dict[str, Any] = None,
result: str = "success"
):
"""记录审计日志"""
entry = {
"timestamp": datetime.now().isoformat(),
"action": action,
"agent_id": agent_id,
"session_id": session_id,
"user_id": user_id,
"details": details or {},
"result": result
}
# 写入文件
self._write_log(entry)
# TODO: 发送到 Go 后端
def log_tool_execution(
self,
tool_name: str,
params: Dict[str, Any],
user_id: str,
agent_id: str,
approved: bool,
result: Any
):
"""记录工具执行"""
self.log(
action="tool_execution",
agent_id=agent_id,
user_id=user_id,
details={
"tool_name": tool_name,
"params": params,
"approved": approved,
"result_preview": str(result)[:200] if result else None
},
result="approved" if approved else "pending_approval"
)
def log_error(self, action: str, error: str, **kwargs):
"""记录错误"""
self.log(
action=action,
details={"error": error, **kwargs},
result="error"
)
def _write_log(self, entry: dict):
"""写入日志文件"""
try:
log_path = Path(self.log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "a", encoding="utf-8") as f:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
except Exception as e:
print(f"Failed to write audit log: {e}")

View File

@@ -1,19 +1,8 @@
# 核心依赖
fastapi>=0.100.0
uvicorn>=0.20.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
httpx>=0.24.0
aiohttp>=3.8.0
python-multipart>=0.0.5
# LLM 支持
openai>=1.0.0
anthropic>=0.18.0
langchain-core>=0.1.0
langchain-openai>=0.0.2
# 可选:向量数据库
chromadb>=0.4.0
# Redis
redis>=4.5.0
python-dotenv>=1.0.0
aiohttp>=3.8.0
redis>=5.0.0