509 lines
14 KiB
Python
509 lines
14 KiB
Python
"""Agent 协调器 - Phase 10.5

统一协调所有 Agent 组件:TeamLeader, RemoteTransport, BackgroundTaskManager, SessionManager
"""
|
||
|
||
from typing import Any
|
||
|
||
from app.agents.background.manager import BackgroundTaskManager, get_background_task_manager
|
||
from app.agents.session.manager import AgentSession, create_agent_session, get_agent_session
|
||
from app.agents.team.leader import TeamLeader
|
||
from app.agents.transport.remote import RemoteTransport
|
||
|
||
|
||
class AgentCoordinator:
    """Agent coordinator.

    Provides a single entry point for agent operations by tying together
    team leaders (``TeamLeader``), the remote transport (``RemoteTransport``),
    the background task manager (``BackgroundTaskManager``) and agent
    sessions (``AgentSession``).
    """

    def __init__(
        self,
        background_manager: BackgroundTaskManager | None = None,
    ):
        """
        Args:
            background_manager: Background task manager to use. When None
                (or any falsy value), the instance returned by
                ``get_background_task_manager()`` is used instead.
        """
        # Team leaders created through this coordinator, keyed by team_id.
        self._team_leaders: dict[str, TeamLeader] = {}
        self._remote_transport = RemoteTransport()
        self._background_manager = background_manager or get_background_task_manager()
        # Sessions created or spawned through this coordinator, keyed by session_id.
        self._sessions: dict[str, AgentSession] = {}

    # === Team collaboration methods ===

    def create_team(self, team_id: str, members: list[str]) -> dict[str, Any]:
        """Create a team and register its leader.

        Args:
            team_id: Team ID.
            members: List of member IDs.

        Returns:
            A ``"created"`` result dict, or an ``"error"`` dict if a team
            with ``team_id`` is already registered.
        """
        if team_id in self._team_leaders:
            return {"status": "error", "message": f"Team '{team_id}' already exists"}

        leader = TeamLeader(team_id=team_id, members=members)
        self._team_leaders[team_id] = leader
        return {
            "status": "created",
            "team_id": team_id,
            "members": members,
        }

    def get_team(self, team_id: str) -> TeamLeader | None:
        """Return the TeamLeader registered for ``team_id``, or None.

        Args:
            team_id: Team ID.
        """
        return self._team_leaders.get(team_id)

    def assign_task(self, team_id: str, description: str, member: str) -> dict[str, Any]:
        """Create a task on the team and assign it to one member.

        Args:
            team_id: Team ID.
            description: Task description.
            member: Member ID to assign the task to.

        Returns:
            Result dict whose ``status`` is ``"assigned"`` when
            ``TeamLeader.assign_task`` reports success and ``"error"``
            otherwise. An ``"error"`` dict is returned if the team is unknown.
        """
        leader = self._team_leaders.get(team_id)
        if not leader:
            return {"status": "error", "message": f"Team '{team_id}' not found"}

        task_id = leader.create_task(description)
        success = leader.assign_task(task_id, member)
        return {
            "status": "assigned" if success else "error",
            "task_id": task_id,
            "assignee": member,
        }

    def broadcast_task(self, team_id: str, description: str) -> dict[str, Any]:
        """Broadcast a task to every member of the team.

        Args:
            team_id: Team ID.
            description: Task description.

        Returns:
            A ``"broadcast"`` dict with the task IDs returned by the leader
            and the member count, or an ``"error"`` dict if the team is unknown.
        """
        leader = self._team_leaders.get(team_id)
        if not leader:
            return {"status": "error", "message": f"Team '{team_id}' not found"}

        task_ids = leader.broadcast_task(description)
        return {
            "status": "broadcast",
            "team_id": team_id,
            "task_ids": task_ids,
            "member_count": len(leader.members),
        }

    def collect_team_results(self, team_id: str) -> dict[str, Any]:
        """Collect the results of the team's tasks.

        Args:
            team_id: Team ID.

        Returns:
            A ``"collected"`` dict with the leader's results and the
            ``completed`` / ``failed`` values from its team status, or an
            ``"error"`` dict if the team is unknown.
        """
        leader = self._team_leaders.get(team_id)
        if not leader:
            return {"status": "error", "message": f"Team '{team_id}' not found"}

        results = leader.collect_results()
        status = leader.get_team_status()
        return {
            "status": "collected",
            "team_id": team_id,
            "results": results,
            "completed": status["completed"],
            "failed": status["failed"],
        }

    def get_team_status(self, team_id: str) -> dict[str, Any]:
        """Return the leader's team status, or an ``"error"`` dict if unknown.

        Args:
            team_id: Team ID.
        """
        leader = self._team_leaders.get(team_id)
        if not leader:
            return {"status": "error", "message": f"Team '{team_id}' not found"}

        return leader.get_team_status()

    # === Background task methods ===

    def submit_background_task(
        self,
        name: str,
        coro: Any,
        *args,
        **kwargs,
    ) -> dict[str, Any]:
        """Submit a background task to the background task manager.

        Args:
            name: Task name.
            coro: Coroutine function, forwarded to ``submit_task``.
            *args: Positional arguments forwarded to ``submit_task``.
            **kwargs: Keyword arguments forwarded to ``submit_task``.

        Returns:
            A ``"submitted"`` dict with the task ID returned by the manager.
        """
        task_id = self._background_manager.submit_task(name, coro, *args, **kwargs)
        return {
            "status": "submitted",
            "task_id": task_id,
            "name": name,
        }

    def cancel_background_task(self, task_id: str) -> dict[str, Any]:
        """Cancel a background task.

        Args:
            task_id: Task ID.

        Returns:
            ``"cancelled"`` if the manager reports success, else ``"error"``.
        """
        success = self._background_manager.cancel_task(task_id)
        return {
            "status": "cancelled" if success else "error",
            "task_id": task_id,
        }

    def get_background_task_status(self, task_id: str) -> dict[str, Any]:
        """Return status, result and error of a background task.

        Args:
            task_id: Task ID.

        Returns:
            A ``"found"`` dict built from the manager's task record, or an
            ``"error"`` dict if the manager returns nothing for ``task_id``.
        """
        task = self._background_manager.get_task_status(task_id)
        if not task:
            return {"status": "error", "message": f"Task '{task_id}' not found"}

        return {
            "status": "found",
            "task_id": task.id,
            "name": task.name,
            "task_status": task.status.value,
            "result": task.result,
            "error": task.error,
        }

    def list_background_tasks(self) -> dict[str, Any]:
        """List all background tasks known to the manager (id, name, status)."""
        tasks = self._background_manager.list_tasks()
        return {
            "status": "list",
            "count": len(tasks),
            "tasks": [
                {
                    "id": t.id,
                    "name": t.name,
                    "status": t.status.value,
                }
                for t in tasks
            ],
        }

    # === Session methods ===

    def create_session(
        self,
        user_id: str | None = None,
        parent_session_id: str | None = None,
    ) -> dict[str, Any]:
        """Create a session and track it in this coordinator.

        Args:
            user_id: User ID.
            parent_session_id: Parent session ID.

        Returns:
            A ``"created"`` dict with the new session's ID.
        """
        session = create_agent_session(
            user_id=user_id,
            parent_session_id=parent_session_id,
        )
        self._sessions[session.session_id] = session
        return {
            "status": "created",
            "session_id": session.session_id,
            "user_id": user_id,
            "parent_session_id": parent_session_id,
        }

    def get_session(self, session_id: str) -> AgentSession | None:
        """Look up a session by ID.

        Sessions tracked by this coordinator take precedence. Otherwise the
        lookup falls back to ``get_agent_session``.

        Args:
            session_id: Session ID.

        Returns:
            The AgentSession, or None if neither source knows it.
        """
        session = self._sessions.get(session_id)
        if session is None:
            # Explicit None check: a `a or b` fallback would also trigger for
            # a tracked session object that happens to be falsy.
            session = get_agent_session(session_id)
        return session

    async def process_session_message(
        self,
        session_id: str,
        message: str,
        response: str,
    ) -> dict[str, Any]:
        """Record a user message / assistant response pair on a session.

        Args:
            session_id: Session ID.
            message: User message.
            response: Assistant response.

        Returns:
            A ``"processed"`` dict with the session's message count, or an
            ``"error"`` dict if the session is unknown.
        """
        session = self.get_session(session_id)
        if not session:
            return {"status": "error", "message": f"Session '{session_id}' not found"}

        await session.process_message(message, response)
        return {
            "status": "processed",
            "session_id": session_id,
            "message_count": session.context.message_count,
        }

    async def spawn_child_session(
        self,
        session_id: str,
        user_id: str | None = None,
    ) -> dict[str, Any]:
        """Spawn a child session from an existing session and track it.

        Args:
            session_id: Parent session ID.
            user_id: User ID for the child session.

        Returns:
            A ``"spawned"`` dict with parent/child IDs and the child's depth,
            or an ``"error"`` dict if the parent session is unknown.
        """
        session = self.get_session(session_id)
        if not session:
            return {"status": "error", "message": f"Session '{session_id}' not found"}

        child = await session.spawn_child_session(user_id=user_id)
        self._sessions[child.session_id] = child
        return {
            "status": "spawned",
            "parent_session_id": session_id,
            "child_session_id": child.session_id,
            "depth": child.context.depth,
        }

    async def get_session_summary_async(self, session_id: str) -> dict[str, Any]:
        """Get a session summary by awaiting the session's async summary.

        Use this from async code; ``coordinate`` uses it for
        ``"session_summary"``.

        Args:
            session_id: Session ID.

        Returns:
            ``{"status": "found", "summary": <summary>}``, or an ``"error"``
            dict if the session is unknown.
        """
        session = self.get_session(session_id)
        if not session:
            return {"status": "error", "message": f"Session '{session_id}' not found"}

        return {"status": "found", "summary": await session.get_session_summary()}

    def get_session_summary(self, session_id: str) -> dict[str, Any]:
        """Get a session summary from synchronous code.

        If no event loop is running in the current thread, the session's
        async summary is run to completion with ``asyncio.run`` and the
        result is returned under ``"summary"``.

        If an event loop IS running, this call cannot block on it. The summary
        is then scheduled with ``asyncio.ensure_future`` and the resulting
        Task is returned under ``"summary"`` (kept for backward
        compatibility). Async callers should use
        ``get_session_summary_async`` instead.

        Args:
            session_id: Session ID.

        Returns:
            ``{"status": "found", "summary": ...}``, or an ``"error"`` dict if
            the session is unknown.
        """
        import asyncio

        session = self.get_session(session_id)
        if not session:
            return {"status": "error", "message": f"Session '{session_id}' not found"}

        # Keep the try limited to loop detection so a RuntimeError raised by
        # the summary itself propagates instead of triggering a second run.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return {"status": "found", "summary": asyncio.run(session.get_session_summary())}

        task = asyncio.ensure_future(session.get_session_summary())
        return {"status": "found", "summary": task}

    # === Remote transport methods ===

    def register_remote_handler(self, event_type: str, handler: Any) -> None:
        """Register a remote message handler on the remote transport.

        Args:
            event_type: Event type.
            handler: Handler callable.
        """
        self._remote_transport.register_handler(event_type, handler)

    async def send_remote_response(
        self,
        session_id: str,
        response: dict[str, Any],
    ) -> bool:
        """Send a response through the remote transport.

        Args:
            session_id: Session ID.
            response: Response payload.

        Returns:
            The transport's send result (whether sending succeeded).
        """
        return await self._remote_transport.send_response(session_id, response)

    async def send_remote_event(
        self,
        session_id: str,
        event: dict[str, Any],
    ) -> bool:
        """Send an event through the remote transport.

        Args:
            session_id: Session ID.
            event: Event payload.

        Returns:
            The transport's send result (whether sending succeeded).
        """
        return await self._remote_transport.send_event(session_id, event)

    async def send_remote_tool_call(
        self,
        session_id: str,
        tool_call: dict[str, Any],
    ) -> bool:
        """Send a tool call through the remote transport.

        Args:
            session_id: Session ID.
            tool_call: Tool call payload.

        Returns:
            The transport's send result (whether sending succeeded).
        """
        return await self._remote_transport.send_tool_call(session_id, tool_call)

    # === Unified coordination entry point ===

    async def coordinate(self, request: dict[str, Any]) -> dict[str, Any]:
        """Dispatch a request to the matching coordinator operation.

        Args:
            request: Request dict with an ``action`` key. The supported actions
                and the keys each one reads are:

                - ``team_create``: team_id, members
                - ``team_assign``: team_id, description, member
                - ``team_broadcast``: team_id, description
                - ``team_collect`` / ``team_status``: team_id
                - ``task_submit``: name, coro, optional args (list), kwargs (dict)
                - ``task_cancel`` / ``task_status``: task_id
                - ``session_create``: optional user_id, parent_session_id
                - ``session_message``: session_id, message, response
                - ``session_spawn``: session_id, optional user_id
                - ``session_summary``: session_id

        Returns:
            The result dict of the dispatched operation, or an ``"error"``
            dict for an unknown action.

        Raises:
            KeyError: If a required key for the action is missing.
        """
        action = request.get("action")

        if action == "team_create":
            return self.create_team(
                team_id=request["team_id"],
                members=request["members"],
            )

        elif action == "team_assign":
            return self.assign_task(
                team_id=request["team_id"],
                description=request["description"],
                member=request["member"],
            )

        elif action == "team_broadcast":
            return self.broadcast_task(
                team_id=request["team_id"],
                description=request["description"],
            )

        elif action == "team_collect":
            return self.collect_team_results(team_id=request["team_id"])

        elif action == "team_status":
            return self.get_team_status(team_id=request["team_id"])

        elif action == "task_submit":
            # name/coro must be positional: `*args` is bound before keyword
            # arguments, so `name=..., coro=..., *args` would collide.
            return self.submit_background_task(
                request["name"],
                request["coro"],
                *request.get("args", []),
                **request.get("kwargs", {}),
            )

        elif action == "task_cancel":
            return self.cancel_background_task(task_id=request["task_id"])

        elif action == "task_status":
            return self.get_background_task_status(task_id=request["task_id"])

        elif action == "session_create":
            return self.create_session(
                user_id=request.get("user_id"),
                parent_session_id=request.get("parent_session_id"),
            )

        elif action == "session_message":
            return await self.process_session_message(
                session_id=request["session_id"],
                message=request["message"],
                response=request["response"],
            )

        elif action == "session_spawn":
            return await self.spawn_child_session(
                session_id=request["session_id"],
                user_id=request.get("user_id"),
            )

        elif action == "session_summary":
            # Already inside a running loop: await the summary directly
            # rather than getting back an un-awaited Task.
            return await self.get_session_summary_async(session_id=request["session_id"])

        else:
            return {"status": "error", "message": f"Unknown action: {action}"}
|
||
|
||
|
||
# Global singleton: created lazily by get_agent_coordinator().
_coordinator: AgentCoordinator | None = None


def get_agent_coordinator() -> AgentCoordinator:
    """Return the global AgentCoordinator, creating it on the first call.

    No locking is performed around creation.
    """
    global _coordinator
    if _coordinator is not None:
        return _coordinator
    _coordinator = AgentCoordinator()
    return _coordinator
|