"""Team agent for multi-agent collaboration."""

import asyncio
import logging
from typing import Any

# Module-level logger named after this module; used by TeamAgent for
# informational and error messages.
logger = logging.getLogger(__name__)


class TeamAgent:
|
||
|
|
"""Team agent that manages multiple agents for collaborative problem solving.
|
||
|
|
|
||
|
|
Supports different strategies:
|
||
|
|
- parallel: All agents respond in parallel, results are aggregated
|
||
|
|
- sequential: Agents respond one by one in sequence
|
||
|
|
- supervisor: A supervisor agent coordinates the work
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(self, provider: Any, model: str, workspace: Any):
|
||
|
|
"""Initialize the team agent.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
provider: LLM provider
|
||
|
|
model: Model name to use
|
||
|
|
workspace: Workspace path
|
||
|
|
"""
|
||
|
|
self.provider = provider
|
||
|
|
self.model = model
|
||
|
|
self.workspace = workspace
|
||
|
|
|
||
|
|
async def chat(
|
||
|
|
self,
|
||
|
|
message: str,
|
||
|
|
session_id: str = "default",
|
||
|
|
supervisor_agent_id: int = 0,
|
||
|
|
member_agent_ids: list[int] | None = None,
|
||
|
|
strategy: str = "parallel",
|
||
|
|
) -> dict[str, Any]:
|
||
|
|
"""Process a team chat message.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
message: User message
|
||
|
|
session_id: Session identifier
|
||
|
|
supervisor_agent_id: Supervisor agent ID (for future use)
|
||
|
|
member_agent_ids: List of member agent IDs to involve
|
||
|
|
strategy: Collaboration strategy (parallel/sequential/supervisor)
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Dict with response and subtask_results
|
||
|
|
"""
|
||
|
|
member_agent_ids = member_agent_ids or []
|
||
|
|
|
||
|
|
logger.info(f"Team chat: strategy={strategy}, members={member_agent_ids}, message={message[:50]}...")
|
||
|
|
|
||
|
|
if strategy == "parallel":
|
||
|
|
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||
|
|
elif strategy == "sequential":
|
||
|
|
return await self._sequential_chat(message, member_agent_ids, session_id)
|
||
|
|
else:
|
||
|
|
# Default to parallel
|
||
|
|
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||
|
|
|
||
|
|
async def _parallel_chat(
|
||
|
|
self,
|
||
|
|
message: str,
|
||
|
|
member_agent_ids: list[int],
|
||
|
|
session_id: str,
|
||
|
|
) -> dict[str, Any]:
|
||
|
|
"""Execute parallel chat with multiple agents.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
message: User message
|
||
|
|
member_agent_ids: List of member agent IDs
|
||
|
|
session_id: Session identifier
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Aggregated response from all agents
|
||
|
|
"""
|
||
|
|
if not member_agent_ids:
|
||
|
|
return {
|
||
|
|
"response": "No member agents specified for team chat.",
|
||
|
|
"subtask_results": [],
|
||
|
|
}
|
||
|
|
|
||
|
|
# Create tasks for each agent
|
||
|
|
tasks = []
|
||
|
|
for agent_id in member_agent_ids:
|
||
|
|
task = self._call_agent(agent_id, message, session_id)
|
||
|
|
tasks.append(task)
|
||
|
|
|
||
|
|
# Execute all tasks in parallel
|
||
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||
|
|
|
||
|
|
# Aggregate results
|
||
|
|
subtask_results = []
|
||
|
|
responses = []
|
||
|
|
|
||
|
|
for i, result in enumerate(results):
|
||
|
|
agent_id = member_agent_ids[i]
|
||
|
|
|
||
|
|
if isinstance(result, Exception):
|
||
|
|
error_msg = f"Agent {agent_id} error: {str(result)}"
|
||
|
|
logger.error(error_msg)
|
||
|
|
subtask_results.append({
|
||
|
|
"agent_id": agent_id,
|
||
|
|
"status": "error",
|
||
|
|
"result": str(result),
|
||
|
|
})
|
||
|
|
else:
|
||
|
|
subtask_results.append({
|
||
|
|
"agent_id": agent_id,
|
||
|
|
"status": "success",
|
||
|
|
"result": result,
|
||
|
|
})
|
||
|
|
responses.append(result)
|
||
|
|
|
||
|
|
# Combine responses
|
||
|
|
if responses:
|
||
|
|
combined_response = self._aggregate_responses(responses)
|
||
|
|
else:
|
||
|
|
combined_response = "All agents failed to respond."
|
||
|
|
|
||
|
|
return {
|
||
|
|
"response": combined_response,
|
||
|
|
"subtask_results": subtask_results,
|
||
|
|
}
|
||
|
|
|
||
|
|
async def _sequential_chat(
|
||
|
|
self,
|
||
|
|
message: str,
|
||
|
|
member_agent_ids: list[int],
|
||
|
|
session_id: str,
|
||
|
|
) -> dict[str, Any]:
|
||
|
|
"""Execute sequential chat with multiple agents.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
message: User message
|
||
|
|
member_agent_ids: List of member agent IDs
|
||
|
|
session_id: Session identifier
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Aggregated response from all agents
|
||
|
|
"""
|
||
|
|
if not member_agent_ids:
|
||
|
|
return {
|
||
|
|
"response": "No member agents specified for team chat.",
|
||
|
|
"subtask_results": [],
|
||
|
|
}
|
||
|
|
|
||
|
|
subtask_results = []
|
||
|
|
responses = []
|
||
|
|
|
||
|
|
for agent_id in member_agent_ids:
|
||
|
|
try:
|
||
|
|
result = await self._call_agent(agent_id, message, session_id)
|
||
|
|
subtask_results.append({
|
||
|
|
"agent_id": agent_id,
|
||
|
|
"status": "success",
|
||
|
|
"result": result,
|
||
|
|
})
|
||
|
|
responses.append(result)
|
||
|
|
except Exception as e:
|
||
|
|
error_msg = f"Agent {agent_id} error: {str(e)}"
|
||
|
|
logger.error(error_msg)
|
||
|
|
subtask_results.append({
|
||
|
|
"agent_id": agent_id,
|
||
|
|
"status": "error",
|
||
|
|
"result": str(e),
|
||
|
|
})
|
||
|
|
|
||
|
|
# Combine responses
|
||
|
|
if responses:
|
||
|
|
combined_response = self._aggregate_responses(responses)
|
||
|
|
else:
|
||
|
|
combined_response = "All agents failed to respond."
|
||
|
|
|
||
|
|
return {
|
||
|
|
"response": combined_response,
|
||
|
|
"subtask_results": subtask_results,
|
||
|
|
}
|
||
|
|
|
||
|
|
async def _call_agent(
|
||
|
|
self,
|
||
|
|
agent_id: int,
|
||
|
|
message: str,
|
||
|
|
session_id: str,
|
||
|
|
) -> str:
|
||
|
|
"""Call an individual agent.
|
||
|
|
|
||
|
|
For now, this is a placeholder that simulates agent responses.
|
||
|
|
In a real implementation, this would call the actual agent.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
agent_id: Agent ID
|
||
|
|
message: User message
|
||
|
|
session_id: Session identifier
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Agent response
|
||
|
|
"""
|
||
|
|
# Simulate agent processing delay
|
||
|
|
await asyncio.sleep(0.5)
|
||
|
|
|
||
|
|
# Return a simulated response
|
||
|
|
return f"Agent {agent_id} processed: {message[:30]}..."
|
||
|
|
|
||
|
|
def _aggregate_responses(self, responses: list[str]) -> str:
|
||
|
|
"""Aggregate multiple agent responses into a single response.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
responses: List of individual agent responses
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Combined response
|
||
|
|
"""
|
||
|
|
if len(responses) == 1:
|
||
|
|
return responses[0]
|
||
|
|
|
||
|
|
header = f"【团队协作结果】共 {len(responses)} 位智能体参与了讨论:\n\n"
|
||
|
|
body = ""
|
||
|
|
|
||
|
|
for i, resp in enumerate(responses, 1):
|
||
|
|
body += f"--- 智能体 {i} ---\n{resp}\n\n"
|
||
|
|
|
||
|
|
return header + body
|