feat: 新增多 Agent 协作系统

- 添加多 Agent 图协作框架 (graph, supervisor, workers)
- 添加迭代器和集成模块
- 添加多 Agent 规划文档

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-10 23:21:37 +08:00
parent ac384ce10b
commit 5ea6f0d31f
15 changed files with 2409 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
"""
多智能体系统
"""
from .types import AgentState, TaskItem, TaskStatus, AgentType, SupervisorDecision, ReviewResult
from .prompts import SUPERVISOR_SYSTEM_PROMPT, REVIEW_SYSTEM_PROMPT, RESEARCH_SYSTEM_PROMPT, CODER_SYSTEM_PROMPT, AGGREGATOR_SYSTEM_PROMPT
from .supervisor import SupervisorAgent
from .graph import create_multi_agent_graph
__all__ = [
"AgentState",
"TaskItem",
"TaskStatus",
"AgentType",
"SupervisorDecision",
"ReviewResult",
"SUPERVISOR_SYSTEM_PROMPT",
"REVIEW_SYSTEM_PROMPT",
"RESEARCH_SYSTEM_PROMPT",
"CODER_SYSTEM_PROMPT",
"AGGREGATOR_SYSTEM_PROMPT",
"SupervisorAgent",
"create_multi_agent_graph",
]

View File

@@ -0,0 +1,130 @@
"""
LangGraph 流程编排
"""
from langgraph.graph import StateGraph, END
from langgraph.graph.graph import CompiledGraph
from .types import AgentState, AgentType
from .supervisor import SupervisorAgent, ResultAggregator
from .workers.research import ResearchWorker
from .workers.coder import CoderWorker
from .workers.review import ReviewWorker
def create_multi_agent_graph(
llm,
tool_registry=None,
max_iterations: int = 3,
max_tasks: int = 10
) -> CompiledGraph:
"""创建多 Agent 流程图
Args:
llm: 语言模型实例
tool_registry: 工具注册表
max_iterations: 最大迭代次数
max_tasks: 最大任务数
Returns:
CompiledGraph: 编译后的 LangGraph
"""
# 初始化组件
supervisor = SupervisorAgent(llm, max_iterations=max_iterations, max_tasks=max_tasks)
research_worker = ResearchWorker(llm, tool_registry)
coder_worker = CoderWorker(llm, tool_registry)
review_worker = ReviewWorker(llm, tool_registry)
aggregator = ResultAggregator(llm)
# 创建图
graph = StateGraph(AgentState)
# 添加节点
graph.add_node("supervisor", supervisor.create_node())
graph.add_node(AgentType.RESEARCH, research_worker.create_node())
graph.add_node(AgentType.CODER, coder_worker.create_node())
graph.add_node(AgentType.REVIEW, review_worker.create_node())
graph.add_node("aggregator", aggregator.create_node())
# 设置入口点
graph.set_entry_point("supervisor")
# 定义条件边函数
def should_continue(state: AgentState) -> str:
"""判断是否继续执行"""
# 获取下一步节点
next_node = state.get("next_node", "aggregator")
# 如果是结束节点
if next_node in ["__end__", "aggregator"]:
return "aggregator"
# 如果是 Worker 节点
if next_node in [AgentType.RESEARCH, AgentType.CODER, AgentType.REVIEW]:
return next_node
# 如果是 supervisor
if next_node == "supervisor":
# 检查迭代次数
iteration = state.get("iteration", 0)
if iteration >= max_iterations:
return "aggregator"
return "supervisor"
# 默认进入汇总
return "aggregator"
# 添加条件边:从 supervisor 出来
graph.add_conditional_edges(
"supervisor",
should_continue,
{
"supervisor": "supervisor",
AgentType.RESEARCH: AgentType.RESEARCH,
AgentType.CODER: AgentType.CODER,
AgentType.REVIEW: AgentType.REVIEW,
"aggregator": "aggregator"
}
)
# 添加边Worker -> Review
graph.add_edge(AgentType.RESEARCH, AgentType.REVIEW)
graph.add_edge(AgentType.CODER, AgentType.REVIEW)
# 添加条件边:从 Review 出来
graph.add_conditional_edges(
AgentType.REVIEW,
should_continue,
{
"supervisor": "supervisor",
"aggregator": "aggregator"
}
)
# 添加边aggregator -> END
graph.add_edge("aggregator", END)
# 编译图
return graph.compile()
def create_simple_graph(llm, tool_registry=None) -> CompiledGraph:
"""创建简单的单 Agent 图(不经过 Supervisor"""
# 创建图
graph = StateGraph(AgentState)
# 直接使用 Coder Worker
coder_worker = CoderWorker(llm, tool_registry)
# 添加节点
graph.add_node("coder", coder_worker.create_node())
# 设置入口
graph.set_entry_point("coder")
# 添加边
graph.add_edge("coder", END)
return graph.compile()

View File

@@ -0,0 +1,223 @@
"""
多智能体系统 - 与现有系统集成
"""
import logging
from typing import Optional
from app.llm.factory import LLMFactory
from app.agent.tools.registry import ToolRegistry
from app.agent.memory.session import SessionManager
from .types import create_initial_state
from .graph import create_multi_agent_graph
logger = logging.getLogger(__name__)
class MultiAgentSystem:
"""多智能体系统 - 集成现有组件"""
def __init__(
self,
llm_provider: str = "openai",
openai_api_key: Optional[str] = None,
anthropic_api_key: Optional[str] = None,
max_iterations: int = 3,
max_tasks: int = 10
):
"""
初始化多智能体系统
Args:
llm_provider: LLM 提供商
openai_api_key: OpenAI API Key
anthropic_api_key: Anthropic API Key
max_iterations: 最大迭代次数
max_tasks: 最大任务数
"""
# 初始化 LLM Factory
self.llm_factory = LLMFactory(
provider=llm_provider,
openai_api_key=openai_api_key,
anthropic_api_key=anthropic_api_key
)
# 初始化 Tool Registry
self.tool_registry = ToolRegistry()
self._register_default_tools()
# 初始化 Session Manager
self.session_manager = SessionManager()
# 配置
self.max_iterations = max_iterations
self.max_tasks = max_tasks
# 图实例(延迟初始化)
self._graph = None
def _register_default_tools(self):
"""注册默认工具"""
try:
from app.agent.tools.impl import search, calculator, time_tool
# 安全工具
self.tool_registry.register(
name="search",
func=search.search_web,
description="Search the web for information",
security_level="safe"
)
self.tool_registry.register(
name="calculator",
func=calculator.calculate,
description="Perform mathematical calculations",
security_level="safe"
)
self.tool_registry.register(
name="get_current_time",
func=time_tool.get_current_time,
description="Get current date and time",
security_level="safe"
)
# 执行代码工具
try:
from app.agent.tools.impl import sandbox
self.tool_registry.register(
name="execute_code",
func=sandbox.sandbox.execute,
description="Execute code in sandbox",
security_level="review",
require_approval=True
)
except ImportError:
pass
except ImportError as e:
logger.warning(f"Failed to import default tools: {e}")
@property
def graph(self):
"""获取或创建 LangGraph"""
if self._graph is None:
llm = self.llm_factory.get_llm()
self._graph = create_multi_agent_graph(
llm=llm,
tool_registry=self.tool_registry,
max_iterations=self.max_iterations,
max_tasks=self.max_tasks
)
return self._graph
async def execute(self, task: str, session_id: str = None) -> dict:
"""
执行多 Agent 任务
Args:
task: 任务描述
session_id: 会话 ID可选
Returns:
dict: 执行结果
"""
# 创建初始状态
initial_state = create_initial_state(task, session_id)
try:
# 执行图
result = await self.graph.ainvoke(initial_state)
# 保存到 session
if session_id:
self.session_manager.add_message(session_id, "user", task)
self.session_manager.add_message(
session_id,
"assistant",
result.get("final_output", "")
)
return {
"success": result.get("status") != "failed",
"output": result.get("final_output", ""),
"status": result.get("status", "unknown"),
"task_plan": result.get("task_plan", []),
"results": result.get("results", {})
}
except Exception as e:
logger.error(f"Multi-agent execution failed: {e}")
return {
"success": False,
"output": f"执行失败: {str(e)}",
"status": "failed",
"error": str(e)
}
async def execute_simple(self, task: str, session_id: str = None) -> dict:
"""
执行简单任务(不使用 Supervisor
Args:
task: 任务描述
session_id: 会话 ID可选
Returns:
dict: 执行结果
"""
from .graph import create_simple_graph
# 创建简单图
llm = self.llm_factory.get_llm()
simple_graph = create_simple_graph(llm, self.tool_registry)
# 创建初始状态
initial_state = create_initial_state(task, session_id)
try:
# 执行图
result = await simple_graph.ainvoke(initial_state)
return {
"success": True,
"output": result.get("final_output", ""),
"status": result.get("status", "completed")
}
except Exception as e:
logger.error(f"Simple execution failed: {e}")
return {
"success": False,
"output": f"执行失败: {str(e)}",
"status": "failed",
"error": str(e)
}
def list_tools(self) -> list:
"""列出所有可用工具"""
return self.tool_registry.list_tools()
# 全局实例
_global_system: Optional[MultiAgentSystem] = None
def get_multi_agent_system(
llm_provider: str = "openai",
openai_api_key: str = None,
anthropic_api_key: str = None,
**kwargs
) -> MultiAgentSystem:
"""获取全局多智能体系统实例"""
global _global_system
if _global_system is None:
_global_system = MultiAgentSystem(
llm_provider=llm_provider,
openai_api_key=openai_api_key,
anthropic_api_key=anthropic_api_key,
**kwargs
)
return _global_system

View File

@@ -0,0 +1,117 @@
"""
迭代控制器
"""
from typing import Optional
class IterationController:
"""迭代控制器 - 管理任务执行的迭代"""
def __init__(
self,
max_iterations: int = 3,
max_retries_per_task: int = 2
):
"""
初始化迭代控制器
Args:
max_iterations: 全局最大迭代次数
max_retries_per_task: 每个任务的最大重试次数
"""
self.max_iterations = max_iterations
self.max_retries_per_task = max_retries_per_task
def should_continue(
self,
iteration: int,
task_status: str,
review_result: Optional[dict] = None
) -> tuple[bool, str]:
"""
判断是否继续迭代
Args:
iteration: 当前迭代次数
task_status: 任务状态
review_result: 评审结果(可选)
Returns:
(是否继续, 原因)
"""
# 超过最大迭代次数
if iteration >= self.max_iterations:
return False, "max_iterations_reached"
# 任务成功完成
if task_status == "completed":
if review_result and review_result.get("passed"):
return False, "task_completed"
elif review_result is None:
return False, "task_completed"
# 任务失败且不可重试
if task_status == "failed":
if review_result and not review_result.get("retryable", True):
return False, "task_failed_non_retryable"
# 检查重试次数
retry_count = review_result.get("retry_count", 0) if review_result else 0
if retry_count >= self.max_retries_per_task:
return False, "max_retries_reached"
# 需要重试
if review_result:
issues = review_result.get("issues", [])
if issues and not review_result.get("passed", True):
return True, "needs_retry"
return True, "continue"
def get_next_action(
self,
review_result: Optional[dict],
current_worker: str
) -> str:
"""
确定下一步动作
Args:
review_result: 评审结果
current_worker: 当前执行的 Worker
Returns:
下一个节点名称
"""
if review_result is None:
return "supervisor"
# 根据评审结果决定下一步
if review_result.get("passed"):
return "supervisor"
# 根据问题类型决定下一步
issues = review_result.get("issues", [])
high_severity = any(i.get("severity") == "high" for i in issues)
if high_severity:
# 严重问题,重新执行相同任务
return current_worker
else:
# 轻微问题,返回 Supervisor
return "supervisor"
def calculate_backoff_delay(self, retry_count: int) -> float:
"""
计算退避延迟(指数退避)
Args:
retry_count: 重试次数
Returns:
延迟时间(秒)
"""
base_delay = 1.0
max_delay = 30.0
delay = min(base_delay * (2 ** retry_count), max_delay)
return delay

View File

@@ -0,0 +1,170 @@
"""
多智能体系统 Prompt 模板
"""
# Supervisor System Prompt
SUPERVISOR_SYSTEM_PROMPT = """你是一个任务规划专家Supervisor。你的职责是将复杂任务分解为可执行的子任务并分配给合适的执行 Agent。
## 可用的 Worker Agent
- **research**: 信息搜索和调研
- **coder**: 代码编写、修改和调试
- **review**: 结果检查、质量评审
## 任务
{task}
## 当前进度
{progress}
## 共享上下文
{context}
## 请按以下步骤执行
### 步骤 1: 任务分析
分析任务的性质,确定需要哪些步骤来完成。
### 步骤 2: 任务分解
将任务分解为独立的子任务。每个子任务应该:
- 描述清晰
- 可以由单个 Agent 完成
- 有明确的完成标准
### 步骤 3: 分配 Agent
为每个子任务选择最合适的执行 Agent。
### 步骤 4: 确定执行顺序
如果有依赖关系,确定正确的执行顺序。
## 输出格式
请以 JSON 格式输出你的决策,包含以下字段:
- analysis: 任务分析
- task_plan: 任务计划数组,每个元素包含 id, description, assigned_agent
- need_aggregation: 是否需要汇总结果
- next_worker: 下一个执行的 Worker 名称 (research/coder/review)
## 注意
- 如果任务很简单,可以只分配给一个 Agent
- 如果任务需要迭代优化,确保有 review 环节
- 考虑任务之间的依赖关系
- 使用 "research"/"coder"/"review" 作为 assigned_agent 的值
"""
# Review Worker System Prompt
REVIEW_SYSTEM_PROMPT = """你是一个代码和结果评审专家Reviewer。你的职责是检查任务执行结果是否符合要求。
## 原始任务
{original_task}
## 当前任务描述
{task_description}
## 执行结果
{execution_result}
## 共享上下文
{context}
## 检查标准
1. 结果是否完整解决了原始任务?
2. 输出格式是否正确?
3. 是否存在明显的错误或遗漏?
4. 代码是否有潜在问题?
5. 是否有安全漏洞或风险?
## 输出格式
请以 JSON 格式输出评审结果:
- passed: true/false是否通过
- issues: 问题数组,每个包含 severity(high/medium/low) 和 description
- suggestions: 改进建议数组
- retryable: true/false是否可以重试
## 注意
- 如果只有轻微问题passed 可以为 true
- 如果有严重问题passed 应为 false
- 判断是否需要重试,而不是立即失败
"""
# Research Worker System Prompt
RESEARCH_SYSTEM_PROMPT = """你是一个信息搜索和调研专家Researcher。你的职责是根据任务要求搜集和整理信息。
## 任务
{task}
## 共享上下文
{context}
## 请执行以下步骤
### 1. 理解任务
明确需要搜集什么信息,信息的用途是什么。
### 2. 搜索信息
使用可用工具搜索相关信息。
### 3. 整理结果
将搜索结果整理成结构化的信息。
## 输出要求
- 提供清晰、结构化的信息整理
- 标注信息来源
- 如果无法完成任务,说明原因
"""
# Coder Worker System Prompt
CODER_SYSTEM_PROMPT = """你是一个代码编写专家Coder。你的职责是根据任务要求编写和修改代码。
## 任务
{task}
## 共享上下文
{context}
## 请执行以下步骤
### 1. 理解需求
明确需要编写什么代码,代码的用途和约束。
### 2. 编写代码
使用合适的编程语言和框架编写代码。
### 3. 代码检查
确保代码语法正确,逻辑合理。
## 输出要求
- 提供完整的、可运行的代码
- 包含必要的注释说明
- 如果需要执行代码,使用代码执行工具
"""
# Aggregator System Prompt
AGGREGATOR_SYSTEM_PROMPT = """你是一个结果汇总专家Aggregator。你的职责是将多个子任务的结果汇总成最终输出。
## 原始任务
{original_task}
## 任务计划
{task_plan}
## 执行结果
{results}
## 共享上下文
{context}
## 请执行以下步骤
### 1. 分析结果
分析每个子任务的执行结果。
### 2. 识别关键信息
从结果中提取关键信息。
### 3. 汇总输出
将所有结果整合成一个连贯的最终输出。
## 输出要求
- 提供清晰、完整的最终结果
- 标注每个部分的来源
- 确保结果解决了原始任务
"""

View File

@@ -0,0 +1,262 @@
"""
Supervisor Agent - 负责任务规划和分发
"""
import json
import re
from typing import Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.messages import HumanMessage, SystemMessage
from .types import AgentState, TaskItem, AgentType, SupervisorDecision
from .prompts import SUPERVISOR_SYSTEM_PROMPT
class SupervisorAgent:
"""Supervisor Agent - 负责任务规划和分发"""
def __init__(
self,
llm: BaseChatModel,
max_iterations: int = 3,
max_tasks: int = 10
):
self.llm = llm
self.max_iterations = max_iterations
self.max_tasks = max_tasks
def create_node(self):
"""创建 LangGraph 节点"""
return self._supervisor_node
async def _supervisor_node(self, state: AgentState) -> dict:
"""Supervisor 节点逻辑"""
# 首次调用:分析任务并生成计划
if not state.get("task_plan"):
decision = await self._plan_tasks(
task=state["original_task"],
progress="这是任务的开始",
context=state.get("shared_context", {})
)
return {
"task_plan": decision.task_plan,
"next_node": decision.next_worker,
"current_task_index": 0,
"shared_context": {
**state.get("shared_context", {}),
"task_analysis": decision.analysis
}
}
# 非首次调用:检查任务状态,决定下一步
current_task_index = state.get("current_task_index", 0)
task_plan = state.get("task_plan", [])
# 获取当前任务
if current_task_index >= len(task_plan):
# 所有任务完成,进入汇总
return {
"next_node": "aggregator",
"shared_context": state.get("shared_context", {})
}
current_task = task_plan[current_task_index]
# 检查当前任务状态
if current_task.status == "completed":
# 当前任务完成,检查是否还有更多任务
if current_task_index + 1 < len(task_plan):
next_index = current_task_index + 1
next_task = task_plan[next_index]
return {
"current_task_index": next_index,
"next_node": next_task.assigned_agent,
"iteration": state.get("iteration", 0),
"shared_context": state.get("shared_context", {})
}
else:
# 所有任务完成,进入汇总
return {
"next_node": "aggregator",
"shared_context": state.get("shared_context", {})
}
elif current_task.status == "failed":
# 任务失败,检查是否超过最大重试
if current_task.retry_count >= self.max_iterations:
# 超过最大重试,进入汇总(标记失败)
return {
"next_node": "aggregator",
"status": "failed",
"shared_context": state.get("shared_context", {})
}
else:
# 重试当前任务
return {
"next_node": current_task.assigned_agent,
"iteration": state.get("iteration", 0) + 1,
"shared_context": state.get("shared_context", {})
}
elif current_task.status == "needs_retry":
# 需要重试(来自 review
return {
"next_node": current_task.assigned_agent,
"iteration": state.get("iteration", 0) + 1,
"shared_context": state.get("shared_context", {})
}
# 默认继续执行
return {
"next_node": state.get("next_node", "aggregator"),
"shared_context": state.get("shared_context", {})
}
async def _plan_tasks(self, task: str, progress: str, context: dict) -> SupervisorDecision:
"""调用 LLM 生成任务计划"""
# 格式化 prompt
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else ""
prompt = SUPERVISOR_SYSTEM_PROMPT.format(
task=task,
progress=progress,
context=context_str
)
# 调用 LLM
response = await self.llm.ainvoke([
SystemMessage(content=prompt),
HumanMessage(content="请分析任务并制定执行计划。")
])
# 解析 LLM 输出
decision = self._parse_response(response.content, task)
return decision
def _parse_response(self, response: str, original_task: str) -> SupervisorDecision:
"""解析 LLM 响应为结构化决策"""
try:
# 尝试提取 JSON
json_match = re.search(r'\{[\s\S]*\}', response)
if json_match:
data = json.loads(json_match.group())
else:
raise ValueError("No JSON found")
# 解析任务计划
task_plan = []
for i, item in enumerate(data.get("task_plan", [])):
task = TaskItem(
id=item.get("id", f"task_{i+1}"),
description=item.get("description", ""),
assigned_agent=AgentType(item.get("assigned_agent", "coder")),
status="pending"
)
task_plan.append(task)
# 确定下一个 Worker
next_worker = data.get("next_worker", "research")
if isinstance(next_worker, dict):
next_worker = next_worker.get("assigned_agent", "research")
return SupervisorDecision(
analysis=data.get("analysis", "任务分析"),
task_plan=task_plan,
need_aggregation=data.get("need_aggregation", True),
next_worker=AgentType(next_worker)
)
except Exception as e:
# 解析失败,创建默认计划
return self._create_default_plan(original_task)
def _create_default_plan(self, task: str) -> SupervisorDecision:
"""创建默认任务计划"""
task_lower = task.lower()
# 根据任务关键词判断
if any(keyword in task_lower for keyword in ["搜索", "查找", "调研", "研究", "research", "search"]):
assigned_agent = AgentType.RESEARCH
elif any(keyword in task_lower for keyword in ["代码", "", "开发", "code", "program", "写代码"]):
assigned_agent = AgentType.CODER
else:
assigned_agent = AgentType.CODER
# 创建默认任务
task_item = TaskItem(
id="task_1",
description=task,
assigned_agent=assigned_agent,
status="pending"
)
return SupervisorDecision(
analysis="简单任务,直接分配给合适的 Agent 执行",
task_plan=[task_item],
need_aggregation=True,
next_worker=assigned_agent
)
class ResultAggregator:
"""结果聚合器 - 汇总多个任务的结果"""
def __init__(self, llm: BaseChatModel):
self.llm = llm
def create_node(self):
"""创建 LangGraph 节点"""
return self._aggregate_node
async def _aggregate_node(self, state: AgentState) -> dict:
"""聚合节点逻辑"""
# 准备任务计划和结果
task_plan = state.get("task_plan", [])
results = state.get("results", {})
original_task = state.get("original_task", "")
# 构建任务描述
task_descriptions = []
for task in task_plan:
task_descriptions.append(f"- {task.id}: {task.description} -> {task.status}")
# 构建结果描述
result_items = []
for task_id, result in results.items():
if isinstance(result, dict):
content = result.get("content", str(result))
else:
content = str(result)
result_items.append(f"## {task_id}\n{content}")
# 调用 LLM 汇总结果
from .prompts import AGGREGATOR_SYSTEM_PROMPT
context_str = json.dumps(state.get("shared_context", {}), ensure_ascii=False, indent=2)
prompt = AGGREGATOR_SYSTEM_PROMPT.format(
original_task=original_task,
task_plan="\n".join(task_descriptions),
results="\n\n".join(result_items) if result_items else "无结果",
context=context_str
)
response = await self.llm.ainvoke([
SystemMessage(content=prompt),
HumanMessage(content="请汇总以上结果,给出最终输出。")
])
# 检查是否有失败的任务
has_failed = any(task.status == "failed" for task in task_plan)
return {
"final_output": response.content,
"status": "failed" if has_failed else "completed",
"next_node": "__end__"
}

View File

@@ -0,0 +1,93 @@
"""
多智能体系统数据类型定义
"""
from typing import TypedDict, Annotated, Optional, Literal
from operator import add
from pydantic import BaseModel, Field
from enum import Enum
class TaskStatus(str, Enum):
"""任务状态"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
NEEDS_RETRY = "needs_retry"
class AgentType(str, Enum):
"""Agent 类型"""
SUPERVISOR = "supervisor"
RESEARCH = "research"
CODER = "coder"
REVIEW = "review"
AGGREGATOR = "aggregator"
class TaskItem(BaseModel):
"""单个任务项"""
id: str = Field(..., description="任务唯一标识")
description: str = Field(..., description="任务描述")
assigned_agent: AgentType = Field(..., description="分配的 Agent 类型")
status: TaskStatus = Field(default=TaskStatus.PENDING, description="任务状态")
result: Optional[dict] = Field(default=None, description="任务执行结果")
error: Optional[str] = Field(default=None, description="错误信息")
retry_count: int = Field(default=0, description="重试次数")
class SupervisorDecision(BaseModel):
"""Supervisor 的结构化决策"""
analysis: str = Field(..., description="任务分析")
task_plan: list[TaskItem] = Field(..., description="任务计划")
need_aggregation: bool = Field(default=True, description="是否需要汇总")
next_worker: AgentType = Field(..., description="下一个执行的 Worker")
class ReviewResult(BaseModel):
"""Review 结果"""
passed: bool = Field(..., description="是否通过")
issues: list[dict] = Field(default_factory=list, description="问题列表")
suggestions: list[str] = Field(default_factory=list, description="改进建议")
retryable: bool = Field(default=True, description="是否可重试")
class AgentState(TypedDict):
"""贯穿整个图的 Agent 状态"""
# 用户输入
original_task: str # 原始任务描述
session_id: Optional[str] # 会话 ID
# 任务规划
task_plan: list[TaskItem] # 分解后的任务列表
current_task_index: int # 当前执行的任务索引
# 执行结果
results: dict # {task_id: result}
# 流程控制
iteration: int # 当前迭代次数
next_node: str # 下一个节点名称
# 共享上下文
shared_context: dict # Agent 间共享的数据
# 最终输出
final_output: str
status: Literal["running", "completed", "failed"] # 运行状态
def create_initial_state(task: str, session_id: str = None) -> AgentState:
"""创建初始状态"""
return {
"original_task": task,
"session_id": session_id,
"task_plan": [],
"current_task_index": 0,
"results": {},
"iteration": 0,
"next_node": "supervisor",
"shared_context": {},
"final_output": "",
"status": "running"
}

View File

@@ -0,0 +1,14 @@
"""
Worker Agents
"""
from .base import BaseWorker
from .research import ResearchWorker
from .coder import CoderWorker
from .review import ReviewWorker
__all__ = [
"BaseWorker",
"ResearchWorker",
"CoderWorker",
"ReviewWorker",
]

View File

@@ -0,0 +1,138 @@
"""
Worker Agent 基类
"""
import json
from abc import ABC, abstractmethod
from typing import Any, Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from ..types import AgentState, TaskItem, TaskStatus
class BaseWorker(ABC):
"""Worker Agent 基类"""
def __init__(
self,
llm: BaseChatModel,
name: str,
system_prompt: str,
tools: list = None,
tool_registry=None
):
self.llm = llm
self.name = name
self.system_prompt = system_prompt
self.tools = tools or []
self.tool_registry = tool_registry
@abstractmethod
async def execute(self, task: TaskItem, context: dict) -> dict:
"""
执行任务
Returns:
dict: {
"success": bool,
"content": str,
"context": dict, # 更新共享上下文
"error": str (optional)
}
"""
pass
def create_node(self):
"""创建 LangGraph 节点"""
async def node(state: AgentState) -> dict:
task_index = state.get("current_task_index", 0)
task_plan = state.get("task_plan", [])
if task_index >= len(task_plan):
return {"next_node": "aggregator"}
task = task_plan[task_index]
shared_context = state.get("shared_context", {})
# 更新任务状态为 running
updated_plan = self._update_task_status(task_plan, task.id, TaskStatus.RUNNING)
try:
# 执行任务
result = await self.execute(task, shared_context)
# 更新任务状态
if result.get("success"):
updated_plan = self._update_task_status(
updated_plan,
task.id,
TaskStatus.COMPLETED,
result=result.get("content", "")
)
else:
updated_plan = self._update_task_status(
updated_plan,
task.id,
TaskStatus.FAILED,
error=result.get("error", "Unknown error")
)
# 构建新上下文
new_context = {**shared_context, **(result.get("context", {}))}
return {
"task_plan": updated_plan,
"results": {**state.get("results", {}), task.id: result},
"shared_context": new_context,
"next_node": "review"
}
except Exception as e:
# 执行出错
updated_plan = self._update_task_status(
updated_plan,
task.id,
TaskStatus.FAILED,
error=str(e)
)
return {
"task_plan": updated_plan,
"results": {**state.get("results", {}), task.id: {"error": str(e)}},
"next_node": "review"
}
return node
def _update_task_status(
self,
tasks: list,
task_id: str,
status: TaskStatus,
result: Any = None,
error: str = None
) -> list:
"""更新任务状态"""
return [
{
**task.model_dump() if hasattr(task, 'model_dump') else task,
"status": status,
"result": result,
"error": error
}
for task in tasks
]
def _build_messages(self, task: str, context: dict) -> list:
"""构建消息列表"""
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else ""
user_prompt = self.system_prompt.format(
task=task,
context=context_str
)
return [
SystemMessage(content=user_prompt),
HumanMessage(content=task)
]

View File

@@ -0,0 +1,146 @@
"""
Coder Worker - 代码编写和修改
"""
import json
from langchain_core.language_models import BaseChatModel
from .base import BaseWorker
from ..types import TaskItem
from ..prompts import CODER_SYSTEM_PROMPT
class CoderWorker(BaseWorker):
"""Coder Worker - 代码编写和修改"""
def __init__(
self,
llm: BaseChatModel,
tool_registry=None,
tools: list = None
):
super().__init__(
llm=llm,
name="coder",
system_prompt=CODER_SYSTEM_PROMPT,
tools=tools or [],
tool_registry=tool_registry
)
async def execute(self, task: TaskItem, context: dict) -> dict:
"""执行编码任务"""
# 构建消息
messages = self._build_messages(task.description, context)
# 如果有代码执行工具,启用它
if self.tool_registry:
tool_defs = self._get_available_tools()
if tool_defs:
try:
response = await self.llm.agenerate(
messages=messages,
tools=tool_defs
)
return self._handle_tool_response(response, messages)
except Exception:
# 如果工具调用失败,回退到普通调用
pass
# 普通调用
try:
response = await self.llm.ainvoke(messages)
content = response.content if hasattr(response, 'content') else str(response)
return {
"success": True,
"content": content,
"context": {
"code_written": True,
"last_coder": self.name
}
}
except Exception as e:
return {
"success": False,
"content": "",
"error": str(e),
"context": {}
}
def _get_available_tools(self) -> list:
"""获取可用工具定义"""
if not self.tool_registry:
return []
tool_names = self.tools or ["search", "execute_code"]
tool_defs = []
for tool_name in tool_names:
tool_def = self.tool_registry.get_tool_definition(tool_name)
if tool_def:
tool_defs.append(tool_def)
return tool_defs
def _handle_tool_response(self, response, original_messages: list) -> dict:
"""处理工具调用响应"""
# 简化实现
response_message = response.generations[0][0]
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
# 有工具调用
tool_results = []
for tool_call in response_message.tool_calls:
tool_name = tool_call.name
tool_args = tool_call.arguments
# 执行工具
try:
tool_func, _ = self.tool_registry.get_tool(tool_name)
result = tool_func(**tool_args)
tool_results.append({
"tool": tool_name,
"result": str(result)
})
except Exception as e:
tool_results.append({
"tool": tool_name,
"error": str(e)
})
# 将工具结果添加到消息
for msg in response.generations[0]:
original_messages.append(msg)
for tool_result in tool_results:
original_messages.append({
"role": "tool",
"content": json.dumps(tool_result, ensure_ascii=False)
})
# 再次调用 LLM 生成最终响应
final_response = await self.llm.ainvoke(original_messages)
content = final_response.content if hasattr(final_response, 'content') else str(final_response)
return {
"success": True,
"content": content,
"context": {
"code_written": True,
"tool_results": tool_results,
"last_coder": self.name
}
}
else:
# 无工具调用
content = response_message.text if hasattr(response_message, 'text') else str(response_message)
return {
"success": True,
"content": content,
"context": {
"code_written": True,
"last_coder": self.name
}
}

View File

@@ -0,0 +1,70 @@
"""
Research Worker - 信息搜索和调研
"""
import json
from langchain_core.language_models import BaseChatModel
from .base import BaseWorker
from ..types import TaskItem
from ..prompts import RESEARCH_SYSTEM_PROMPT
class ResearchWorker(BaseWorker):
"""Research Worker - 信息搜索和调研"""
def __init__(
self,
llm: BaseChatModel,
tool_registry=None,
tools: list = None
):
super().__init__(
llm=llm,
name="research",
system_prompt=RESEARCH_SYSTEM_PROMPT,
tools=tools or [],
tool_registry=tool_registry
)
async def execute(self, task: TaskItem, context: dict) -> dict:
"""执行调研任务"""
# 构建消息
messages = self._build_messages(task.description, context)
try:
# 调用 LLM
response = await self.llm.ainvoke(messages)
content = response.content if hasattr(response, 'content') else str(response)
# 尝试提取搜索结果
search_results = self._extract_search_results(content)
return {
"success": True,
"content": content,
"context": {
"research_results": search_results,
"last_research_by": self.name
}
}
except Exception as e:
return {
"success": False,
"content": "",
"error": str(e),
"context": {}
}
def _extract_search_results(self, content: str) -> list:
"""从内容中提取搜索结果"""
# 简单实现:查找以 - 或 * 开头的行
results = []
for line in content.split('\n'):
line = line.strip()
if line.startswith(('- ', '* ', '1. ', '2. ', '3. ')):
results.append(line.lstrip('-*123. '))
return results[:10] # 限制数量

View File

@@ -0,0 +1,174 @@
"""
Review Worker - 结果检查和质量评审
"""
import json
import re
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from .base import BaseWorker
from ..types import AgentState, TaskItem, TaskStatus, ReviewResult
from ..prompts import REVIEW_SYSTEM_PROMPT
class ReviewWorker(BaseWorker):
"""Review Worker - 结果检查和质量评审"""
def __init__(
self,
llm: BaseChatModel,
tool_registry=None,
tools: list = None
):
super().__init__(
llm=llm,
name="review",
system_prompt=REVIEW_SYSTEM_PROMPT,
tools=tools or [],
tool_registry=tool_registry
)
async def execute(self, task: TaskItem, context: dict) -> dict:
"""执行评审任务"""
# 获取当前任务索引和任务计划
# 注意:这里需要从 context 中获取更多信息
# 构建 prompt
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else ""
prompt = REVIEW_SYSTEM_PROMPT.format(
original_task=context.get("original_task", ""),
task_description=task.description,
execution_result=task.result if task.result else "无结果",
context=context_str
)
try:
# 调用 LLM 进行评审
response = await self.llm.ainvoke([
SystemMessage(content=prompt),
HumanMessage(content="请评审以上执行结果。")
])
# 解析评审结果
review_result = self._parse_review_response(response.content)
# 根据评审结果决定下一步
if review_result.passed:
# 通过,更新任务状态为 completed
new_status = TaskStatus.COMPLETED
next_node = "supervisor" # 返回 Supervisor 继续执行
else:
# 未通过,检查是否可重试
if review_result.retryable:
new_status = TaskStatus.NEEDS_RETRY
next_node = "supervisor" # 返回 Supervisor 决定是否重试
else:
new_status = TaskStatus.FAILED
next_node = "aggregator" # 失败,进入汇总
return {
"success": review_result.passed,
"content": response.content,
"review_result": review_result.model_dump() if hasattr(review_result, 'model_dump') else dict(review_result),
"context": {
"review_passed": review_result.passed,
"issues": review_result.issues,
"last_review_by": self.name
}
}
except Exception as e:
return {
"success": False,
"content": "",
"error": str(e),
"context": {}
}
def _parse_review_response(self, response: str) -> ReviewResult:
"""解析评审响应"""
try:
# 尝试提取 JSON
json_match = re.search(r'\{[\s\S]*\}', response)
if json_match:
data = json.loads(json_match.group())
else:
raise ValueError("No JSON found")
return ReviewResult(
passed=data.get("passed", True),
issues=data.get("issues", []),
suggestions=data.get("suggestions", []),
retryable=data.get("retryable", True)
)
except Exception:
# 解析失败,默认通过
return ReviewResult(
passed=True,
issues=[],
suggestions=[],
retryable=True
)
def create_node(self):
"""创建 LangGraph 节点"""
async def node(state: AgentState) -> dict:
task_index = state.get("current_task_index", 0)
task_plan = state.get("task_plan", [])
if task_index >= len(task_plan):
return {"next_node": "aggregator"}
task = task_plan[task_index]
shared_context = {
**state.get("shared_context", {}),
"original_task": state.get("original_task", "")
}
try:
# 执行评审
result = await self.execute(task, shared_context)
# 更新任务状态
review_passed = result.get("review_result", {}).get("passed", True)
retryable = result.get("review_result", {}).get("retryable", True)
if review_passed:
updated_status = TaskStatus.COMPLETED
elif retryable:
updated_status = TaskStatus.NEEDS_RETRY
else:
updated_status = TaskStatus.FAILED
updated_plan = self._update_task_status(
task_plan,
task.id,
updated_status,
result=task.result
)
# 确定下一步
if updated_status == TaskStatus.COMPLETED:
next_node = "supervisor"
elif updated_status == TaskStatus.NEEDS_RETRY:
next_node = "supervisor"
else:
next_node = "aggregator"
return {
"task_plan": updated_plan,
"results": {**state.get("results", {}), f"{task.id}_review": result},
"shared_context": {**shared_context, **result.get("context", {})},
"next_node": next_node
}
except Exception as e:
return {
"next_node": "aggregator",
"results": {**state.get("results", {}), f"{task.id}_review": {"error": str(e)}}
}
return node