Files
X-Agents/agent/app/agent/multi/supervisor.py
DESKTOP-72TV0V4\caoxiaozhu 5ea6f0d31f feat: 新增多 Agent 协作系统
- 添加多 Agent 图协作框架 (graph, supervisor, workers)
- 添加迭代器和集成模块
- 添加多 Agent 规划文档

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-10 23:21:37 +08:00

263 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Supervisor Agent - 负责任务规划和分发
"""
import json
import logging
import re
from typing import Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.messages import HumanMessage, SystemMessage

from .types import AgentState, TaskItem, AgentType, SupervisorDecision
from .prompts import SUPERVISOR_SYSTEM_PROMPT
class SupervisorAgent:
    """Supervisor agent: plans the task and dispatches sub-tasks to workers.

    The node returned by :meth:`create_node` is invoked repeatedly. On the
    first call (no ``task_plan`` in the state yet) it asks the LLM for a plan.
    On later calls it inspects the status of the task at
    ``current_task_index`` and returns a partial state update whose
    ``next_node`` is either a worker agent or ``"aggregator"``.
    """

    # Keywords that send the fallback single-task plan to the research agent;
    # every other task falls back to the coder agent.
    _RESEARCH_KEYWORDS = ("搜索", "查找", "调研", "研究", "research", "search")

    def __init__(
        self,
        llm: BaseChatModel,
        max_iterations: int = 3,
        max_tasks: int = 10
    ):
        """
        Args:
            llm: Chat model used to generate the task plan.
            max_iterations: Retry budget per task. Once a task's
                ``retry_count`` reaches this value it is not retried again
                and the run is routed to the aggregator as failed.
            max_tasks: Maximum number of tasks kept from an LLM-generated plan.
        """
        self.llm = llm
        self.max_iterations = max_iterations
        self.max_tasks = max_tasks

    def create_node(self):
        """Return the coroutine to register as the LangGraph supervisor node."""
        return self._supervisor_node

    async def _supervisor_node(self, state: AgentState) -> dict:
        """Plan on the first call, then route based on the current task's status.

        Args:
            state: Graph state. Reads ``original_task``, ``task_plan``,
                ``current_task_index``, ``iteration``, ``shared_context``
                and ``next_node``.

        Returns:
            A partial state update. ``next_node`` names the node to run next.
        """
        shared_context = state.get("shared_context", {})

        # First call: there is no plan yet, so analyze the task and build one.
        if not state.get("task_plan"):
            decision = await self._plan_tasks(
                task=state["original_task"],
                progress="这是任务的开始",
                context=shared_context
            )
            # current_task_index starts at 0, so dispatch to the agent that owns
            # task 0; the LLM's free-form next_worker may disagree with the plan.
            first_agent = (
                decision.task_plan[0].assigned_agent
                if decision.task_plan
                else decision.next_worker
            )
            return {
                "task_plan": decision.task_plan,
                "next_node": first_agent,
                "current_task_index": 0,
                "shared_context": {
                    **shared_context,
                    "task_analysis": decision.analysis
                }
            }

        task_plan = state.get("task_plan", [])
        current_task_index = state.get("current_task_index", 0)
        iteration = state.get("iteration", 0)

        # Index past the end of the plan: every task has been handled.
        if current_task_index >= len(task_plan):
            return {"next_node": "aggregator", "shared_context": shared_context}

        current_task = task_plan[current_task_index]
        status = current_task.status

        if status == "completed":
            next_index = current_task_index + 1
            if next_index < len(task_plan):
                return {
                    "current_task_index": next_index,
                    "next_node": task_plan[next_index].assigned_agent,
                    "iteration": iteration,
                    "shared_context": shared_context
                }
            # Last task done: summarize.
            return {"next_node": "aggregator", "shared_context": shared_context}

        if status in ("failed", "needs_retry"):
            # Failures and reviewer-requested retries share one retry budget;
            # without a limit, a task that keeps coming back as "needs_retry"
            # would be re-dispatched without bound.
            # NOTE(review): this node never increments retry_count itself; it
            # presumably is incremented by the worker/reviewer — confirm.
            if current_task.retry_count >= self.max_iterations:
                return {
                    "next_node": "aggregator",
                    "status": "failed",
                    "shared_context": shared_context
                }
            return {
                "next_node": current_task.assigned_agent,
                "iteration": iteration + 1,
                "shared_context": shared_context
            }

        # Any other status (e.g. still "pending"): keep the current route.
        return {
            "next_node": state.get("next_node", "aggregator"),
            "shared_context": shared_context
        }

    async def _plan_tasks(self, task: str, progress: str, context: dict) -> SupervisorDecision:
        """Ask the LLM for a task plan and parse it into a SupervisorDecision.

        Args:
            task: The original user task.
            progress: Free-text description of the progress so far.
            context: Shared context; serialized as JSON into the prompt
                (an empty/falsy context becomes an empty string).
        """
        context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else ""
        prompt = SUPERVISOR_SYSTEM_PROMPT.format(
            task=task,
            progress=progress,
            context=context_str
        )
        response = await self.llm.ainvoke([
            SystemMessage(content=prompt),
            HumanMessage(content="请分析任务并制定执行计划。")
        ])
        return self._parse_response(response.content, task)

    def _parse_response(self, response: str, original_task: str) -> SupervisorDecision:
        """Parse the LLM reply into a SupervisorDecision.

        Extracts the span from the first ``{`` to the last ``}`` and reads
        ``analysis``, ``task_plan``, ``need_aggregation`` and ``next_worker``
        from it. At most ``self.max_tasks`` plan items are used. Any parsing or
        validation error, or an empty plan, falls back to
        :meth:`_create_default_plan` (the error is logged, not raised).
        """
        try:
            # Greedy match so surrounding prose or code fences are ignored.
            json_match = re.search(r'\{[\s\S]*\}', response)
            if not json_match:
                raise ValueError("No JSON found")
            data = json.loads(json_match.group())

            raw_items = (data.get("task_plan") or [])[: self.max_tasks]
            task_plan = [
                TaskItem(
                    id=item.get("id", f"task_{i+1}"),
                    description=item.get("description", ""),
                    assigned_agent=AgentType(item.get("assigned_agent", "coder")),
                    status="pending"
                )
                for i, item in enumerate(raw_items)
            ]
            if not task_plan:
                # An empty plan would make the supervisor node re-plan on
                # every call, because it plans whenever task_plan is falsy.
                raise ValueError("Empty task plan")

            # next_worker may be an agent name or a task-like dict.
            next_worker = data.get("next_worker", "research")
            if isinstance(next_worker, dict):
                next_worker = next_worker.get("assigned_agent", "research")

            return SupervisorDecision(
                analysis=data.get("analysis", "任务分析"),
                task_plan=task_plan,
                need_aggregation=data.get("need_aggregation", True),
                next_worker=AgentType(next_worker)
            )
        except Exception:
            # Deliberate best-effort fallback, but make bad LLM output visible.
            logging.getLogger(__name__).warning(
                "Failed to parse supervisor plan; falling back to default plan",
                exc_info=True,
            )
            return self._create_default_plan(original_task)

    def _create_default_plan(self, task: str) -> SupervisorDecision:
        """Build a single-task fallback plan covering the whole *task*.

        The task goes to the research agent when it contains one of
        ``_RESEARCH_KEYWORDS`` (case-insensitive for ASCII), otherwise to the
        coder agent.
        """
        task_lower = task.lower()
        if any(keyword in task_lower for keyword in self._RESEARCH_KEYWORDS):
            assigned_agent = AgentType.RESEARCH
        else:
            # Coding tasks and anything unrecognized both go to the coder.
            assigned_agent = AgentType.CODER

        task_item = TaskItem(
            id="task_1",
            description=task,
            assigned_agent=assigned_agent,
            status="pending"
        )
        return SupervisorDecision(
            analysis="简单任务,直接分配给合适的 Agent 执行",
            task_plan=[task_item],
            need_aggregation=True,
            next_worker=assigned_agent
        )
class ResultAggregator:
    """Result aggregator: summarizes the results of all tasks into a final output."""

    def __init__(self, llm: BaseChatModel):
        """
        Args:
            llm: Chat model used to write the final summary.
        """
        self.llm = llm

    def create_node(self):
        """Return the coroutine to register as the LangGraph aggregator node."""
        return self._aggregate_node

    async def _aggregate_node(self, state: AgentState) -> dict:
        """Summarize the task plan and per-task results with the LLM.

        Args:
            state: Graph state. Reads ``task_plan``, ``results``,
                ``original_task``, ``shared_context`` and ``status``.

        Returns:
            ``{"final_output": <LLM reply content>, "status": ...,
            "next_node": "__end__"}``. ``status`` is ``"failed"`` when any task
            is marked ``"failed"`` or the incoming state already carries
            ``status == "failed"``; otherwise ``"completed"``.
        """
        from .prompts import AGGREGATOR_SYSTEM_PROMPT

        # `or` guards against keys that are present but set to None.
        task_plan = state.get("task_plan") or []
        results = state.get("results") or {}
        original_task = state.get("original_task", "")

        # One line per planned task: "- <id>: <description> -> <status>".
        task_lines = [
            f"- {task.id}: {task.description} -> {task.status}" for task in task_plan
        ]

        # One markdown section per result; dict results contribute their
        # "content" field when present, anything else is stringified.
        result_sections = []
        for task_id, result in results.items():
            if isinstance(result, dict):
                content = result.get("content", str(result))
            else:
                content = str(result)
            result_sections.append(f"## {task_id}\n{content}")

        context_str = json.dumps(state.get("shared_context", {}), ensure_ascii=False, indent=2)
        prompt = AGGREGATOR_SYSTEM_PROMPT.format(
            original_task=original_task,
            task_plan="\n".join(task_lines),
            results="\n\n".join(result_sections) if result_sections else "无结果",
            context=context_str
        )
        response = await self.llm.ainvoke([
            SystemMessage(content=prompt),
            HumanMessage(content="请汇总以上结果,给出最终输出。")
        ])

        # The supervisor routes here with status="failed" when a task's retry
        # budget is exhausted; keep that signal instead of overwriting it.
        has_failed = state.get("status") == "failed" or any(
            task.status == "failed" for task in task_plan
        )
        return {
            "final_output": response.content,
            "status": "failed" if has_failed else "completed",
            "next_node": "__end__"
        }