- 添加多 Agent 图协作框架 (graph, supervisor, workers) - 添加迭代器和集成模块 - 添加多 Agent 规划文档 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
263 lines
9.1 KiB
Python
263 lines
9.1 KiB
Python
"""
|
||
Supervisor Agent - 负责任务规划和分发
|
||
"""
|
||
import json
import logging
import re
from typing import Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import PydanticOutputParser

from .prompts import SUPERVISOR_SYSTEM_PROMPT
from .types import AgentState, AgentType, SupervisorDecision, TaskItem


class SupervisorAgent:
    """Supervisor agent: plans the overall task and routes subtasks to workers.

    On its first invocation (no ``task_plan`` in state yet) it asks the LLM for
    a task plan. On later invocations it inspects the status of the task at
    ``current_task_index`` and returns a partial state update whose
    ``next_node`` names the graph node to run next.
    """

    # Used to report (rather than silently swallow) unparseable LLM plans.
    _logger = logging.getLogger(__name__)

    # Keywords that route the fallback single-task plan to the research agent;
    # anything else goes to the coder agent.
    _RESEARCH_KEYWORDS = ("搜索", "查找", "调研", "研究", "research", "search")

    def __init__(
        self,
        llm: BaseChatModel,
        max_iterations: int = 3,
        max_tasks: int = 10,
    ):
        """
        Args:
            llm: Chat model used to produce the task plan.
            max_iterations: Once a task's ``retry_count`` reaches this value,
                the supervisor stops retrying it and routes to the aggregator.
            max_tasks: Maximum number of tasks kept from an LLM-generated plan.
        """
        self.llm = llm
        self.max_iterations = max_iterations
        self.max_tasks = max_tasks

    def create_node(self):
        """Return the coroutine to register as the supervisor node in LangGraph."""
        return self._supervisor_node

    async def _supervisor_node(self, state: AgentState) -> dict:
        """Supervisor node: plan on the first call, otherwise route by task status.

        Args:
            state: Graph state. Reads ``task_plan``, ``original_task``,
                ``shared_context``, ``current_task_index``, ``iteration`` and
                ``next_node``.

        Returns:
            A partial state update containing at least ``next_node``.
        """
        shared_context = state.get("shared_context", {})

        # First call: there is no plan yet, so analyse the task and produce one.
        if not state.get("task_plan"):
            return await self._initial_plan(state, shared_context)

        return self._route(state, shared_context)

    async def _initial_plan(self, state: AgentState, shared_context: dict) -> dict:
        """Ask the LLM for a task plan and position the graph at the first task."""
        decision = await self._plan_tasks(
            task=state["original_task"],
            progress="这是任务的开始",
            context=shared_context,
        )
        return {
            "task_plan": decision.task_plan,
            "next_node": decision.next_worker,
            "current_task_index": 0,
            "shared_context": {
                **shared_context,
                "task_analysis": decision.analysis,
            },
        }

    def _route(self, state: AgentState, shared_context: dict) -> dict:
        """Decide the next node from the status of the current task."""
        task_plan = state.get("task_plan", [])
        current_task_index = state.get("current_task_index", 0)
        iteration = state.get("iteration", 0)

        # Index past the end of the plan: every task was handled, so aggregate.
        if current_task_index >= len(task_plan):
            return {"next_node": "aggregator", "shared_context": shared_context}

        current_task = task_plan[current_task_index]

        if current_task.status == "completed":
            next_index = current_task_index + 1
            if next_index < len(task_plan):
                # Advance to the next task and dispatch it to its assigned worker.
                return {
                    "current_task_index": next_index,
                    "next_node": task_plan[next_index].assigned_agent,
                    "iteration": iteration,
                    "shared_context": shared_context,
                }
            # That was the last task: aggregate the results.
            return {"next_node": "aggregator", "shared_context": shared_context}

        # "failed" and "needs_retry" (set by review) share the same retry cap, so a
        # reviewer that keeps requesting changes cannot loop the graph forever.
        if current_task.status in ("failed", "needs_retry"):
            if current_task.retry_count >= self.max_iterations:
                # Retries exhausted: stop and aggregate, marking the run as failed.
                return {
                    "next_node": "aggregator",
                    "status": "failed",
                    "shared_context": shared_context,
                }
            # NOTE(review): retry_count is not incremented here; this assumes the
            # worker or reviewer increments it, otherwise the cap never triggers.
            return {
                "next_node": current_task.assigned_agent,
                "iteration": iteration + 1,
                "shared_context": shared_context,
            }

        # Any other status: keep the previously chosen next node.
        return {
            "next_node": state.get("next_node", "aggregator"),
            "shared_context": shared_context,
        }

    async def _plan_tasks(self, task: str, progress: str, context: dict) -> SupervisorDecision:
        """Ask the LLM for a task plan and parse it into a ``SupervisorDecision``.

        Args:
            task: The original user task.
            progress: Free-text progress note inserted into the prompt.
            context: Shared context, rendered as JSON (``"无"`` when empty).

        Returns:
            The parsed decision, or a default single-task plan if parsing fails.
        """
        # default=str: shared_context may hold values json cannot encode natively.
        # Stringify them instead of crashing while building the prompt.
        if context:
            context_str = json.dumps(context, ensure_ascii=False, indent=2, default=str)
        else:
            context_str = "无"

        prompt = SUPERVISOR_SYSTEM_PROMPT.format(
            task=task,
            progress=progress,
            context=context_str,
        )

        response = await self.llm.ainvoke([
            SystemMessage(content=prompt),
            HumanMessage(content="请分析任务并制定执行计划。"),
        ])

        return self._parse_response(response.content, task)

    @staticmethod
    def _extract_json_object(text: str) -> Optional[dict]:
        """Return the first JSON object embedded in *text*, or None if there is none.

        Decoding is attempted at each ``{`` in turn. Prose, code fences or
        stray braces around the object therefore do not break parsing.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, dict):
                    return obj
            start = text.find("{", start + 1)
        return None

    def _parse_response(self, response: str, original_task: str) -> SupervisorDecision:
        """Parse the LLM response into a structured ``SupervisorDecision``.

        Any malformed output falls back to ``_create_default_plan``. This covers
        missing JSON, an empty plan, and an unknown agent type on a task.

        Args:
            response: Raw LLM message content.
            original_task: The user task, used for the fallback plan.
        """
        try:
            data = self._extract_json_object(response)
            if data is None:
                raise ValueError("No JSON found")

            # Enforce max_tasks before validating items.
            raw_items = data.get("task_plan", [])[: self.max_tasks]
            task_plan = [
                TaskItem(
                    id=item.get("id", f"task_{i + 1}"),
                    description=item.get("description", ""),
                    assigned_agent=AgentType(item.get("assigned_agent", "coder")),
                    status="pending",
                )
                for i, item in enumerate(raw_items)
            ]
            if not task_plan:
                # An empty plan would make the supervisor re-plan on every call,
                # because it checks `not state.get("task_plan")`. Fall back instead.
                raise ValueError("Empty task plan")

            next_worker = data.get("next_worker")
            if isinstance(next_worker, dict):
                next_worker = next_worker.get("assigned_agent")
            try:
                next_agent = AgentType(next_worker)
            except ValueError:
                # Missing or unknown worker: dispatch the first planned task,
                # which is what current_task_index == 0 points at.
                next_agent = task_plan[0].assigned_agent

            return SupervisorDecision(
                analysis=data.get("analysis", "任务分析"),
                task_plan=task_plan,
                need_aggregation=data.get("need_aggregation", True),
                next_worker=next_agent,
            )

        except Exception:
            # Deliberate best-effort: unexpected LLM output degrades to a default plan,
            # but the failure is logged instead of silently swallowed.
            self._logger.warning(
                "Failed to parse supervisor plan; falling back to default plan",
                exc_info=True,
            )
            return self._create_default_plan(original_task)

    def _create_default_plan(self, task: str) -> SupervisorDecision:
        """Build a one-task fallback plan, choosing the worker by keyword.

        Tasks containing a research/search keyword go to the research agent.
        Everything else, including coding tasks, goes to the coder agent.
        """
        task_lower = task.lower()
        if any(keyword in task_lower for keyword in self._RESEARCH_KEYWORDS):
            assigned_agent = AgentType.RESEARCH
        else:
            assigned_agent = AgentType.CODER

        task_item = TaskItem(
            id="task_1",
            description=task,
            assigned_agent=assigned_agent,
            status="pending",
        )

        return SupervisorDecision(
            analysis="简单任务,直接分配给合适的 Agent 执行",
            task_plan=[task_item],
            need_aggregation=True,
            next_worker=assigned_agent,
        )


class ResultAggregator:
    """Result aggregator: asks the LLM to merge all task results into a final output."""

    def __init__(self, llm: BaseChatModel):
        """
        Args:
            llm: Chat model used to write the final summary.
        """
        self.llm = llm

    def create_node(self):
        """Return the coroutine to register as the aggregator node in LangGraph."""
        return self._aggregate_node

    @staticmethod
    def _format_task_plan(task_plan) -> str:
        """Render one ``- id: description -> status`` line per task ("无" if none)."""
        lines = [f"- {task.id}: {task.description} -> {task.status}" for task in task_plan]
        return "\n".join(lines) if lines else "无"

    @staticmethod
    def _format_results(results: dict) -> str:
        """Render each result as a ``## task_id`` section ("无结果" if none).

        Dict results contribute their ``content`` entry when present, otherwise
        their ``str()``. Non-dict results are stringified.
        """
        sections = []
        for task_id, result in results.items():
            if isinstance(result, dict):
                content = result.get("content", str(result))
            else:
                content = str(result)
            sections.append(f"## {task_id}\n{content}")
        return "\n\n".join(sections) if sections else "无结果"

    async def _aggregate_node(self, state: AgentState) -> dict:
        """Aggregator node: summarise all results and end the graph.

        Args:
            state: Graph state. Reads ``task_plan``, ``results``,
                ``original_task`` and ``shared_context``.

        Returns:
            ``final_output`` (LLM summary), ``status`` (``"failed"`` if any
            task's status is ``"failed"``, else ``"completed"``) and
            ``next_node="__end__"``.
        """
        from .prompts import AGGREGATOR_SYSTEM_PROMPT

        # `or` also guards against keys that are present but set to None.
        task_plan = state.get("task_plan") or []
        results = state.get("results") or {}
        original_task = state.get("original_task", "")
        shared_context = state.get("shared_context", {})

        # default=str: shared_context may hold values json cannot encode natively.
        # "无" for an empty context matches how the supervisor renders its prompt.
        if shared_context:
            context_str = json.dumps(shared_context, ensure_ascii=False, indent=2, default=str)
        else:
            context_str = "无"

        prompt = AGGREGATOR_SYSTEM_PROMPT.format(
            original_task=original_task,
            task_plan=self._format_task_plan(task_plan),
            results=self._format_results(results),
            context=context_str,
        )

        response = await self.llm.ainvoke([
            SystemMessage(content=prompt),
            HumanMessage(content="请汇总以上结果,给出最终输出。"),
        ])

        has_failed = any(task.status == "failed" for task in task_plan)

        return {
            "final_output": response.content,
            "status": "failed" if has_failed else "completed",
            "next_node": "__end__",
        }