Add Day 4 visibility endpoints and response models, strengthen collaboration/task verification behavior, and patch conversation schema startup migration for agent_state compatibility. Extend backend regression coverage for runtime schemas, verifier behavior, visibility APIs, router auth, and legacy conversation list loading.
2336 lines
94 KiB
Python
2336 lines
94 KiB
Python
"""Jarvis agent graph orchestration."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
from uuid import uuid4
|
||
from typing import Any, Literal, cast
|
||
|
||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
||
from langgraph.graph import END, StateGraph
|
||
|
||
from app.agents.prompts import (
|
||
ANALYST_SYSTEM_PROMPT,
|
||
COORDINATOR_SYSTEM_PROMPT,
|
||
EXECUTOR_SYSTEM_PROMPT,
|
||
JSON_ACTION_FALLBACK_PROMPT,
|
||
LIBRARIAN_SYSTEM_PROMPT,
|
||
MASTER_SYSTEM_PROMPT,
|
||
SCHEDULE_PLANNER_SYSTEM_PROMPT,
|
||
)
|
||
from app.agents.registry import load_builtin_registry_indexes
|
||
from app.agents.schemas.event import AgentEvent
|
||
from app.agents.schemas.message import AgentMessage
|
||
from app.agents.schemas.task import AgentTask, CollaborationBudget, InterruptRecord, RecoveryRecord, TaskResult
|
||
from app.agents.skill_registry import build_skill_context
|
||
from app.agents.state import AgentRole, AgentState
|
||
from app.agents.tools import SUB_COMMANDER_TOOLSETS
|
||
from app.agents.tools.time_reasoning import normalize_tool_time_arguments
|
||
from app.agents.verifier import apply_verification_verdict, normalize_task_result, verify_task_result
|
||
from app.services.llm_service import (
|
||
create_llm_from_config,
|
||
default_provider_capabilities,
|
||
get_llm,
|
||
resolve_provider_capabilities,
|
||
)
|
||
|
||
# Module logger for agent-graph orchestration.
logger = logging.getLogger("jarvis.agent")

# Extra instruction text appended per sub-commander. Keys are exactly the names
# returned by _choose_sub_commander(); values are Chinese runtime prompt text and
# must stay verbatim. Rough English gloss:
#   schedule_analysis   - judge/focus the schedule; give conclusions and reasons only, never persist
#   schedule_planning   - turn the goal into a concrete arrangement; call tools directly when the
#                         user asks to create reminders/tasks/goals
#   executor_tasks      - perform task/reminder/goal/to-do operations, calling tools as needed
#   executor_forum      - handle forum and command-post operations only
#   librarian_retrieval - knowledge retrieval and web search; answer from evidence
#   librarian_graph     - knowledge-graph context and structured capture
#   analyst_progress    - assess progress and summarize current status
#   analyst_insights    - judge trends, risks and opportunities; call retrieval tools if needed
SUB_COMMANDER_PROMPTS = {
    "schedule_analysis": "你负责做日程判断与聚焦分析,只给判断和依据,不直接落库。",
    "schedule_planning": "你负责把当前目标转成明确安排;当用户要求创建提醒/任务/目标时,直接调用工具执行。",
    "executor_tasks": "你负责执行任务、提醒、目标、待办相关操作,需要时直接调用工具。",
    "executor_forum": "你只处理论坛与指令帖相关操作。",
    "librarian_retrieval": "你负责知识检索与外部搜索,基于证据回答问题。",
    "librarian_graph": "你负责知识图谱上下文与结构化沉淀。",
    "analyst_progress": "你负责进度研判,汇总当前推进情况。",
    "analyst_insights": "你负责趋势、风险和机会判断,必要时调用检索工具。",
}

# Skill-context identifier for each specialist role. The master role has no entry.
# NOTE(review): presumably passed to build_skill_context(); confirm at the call site.
ROLE_SKILL_CONTEXT = {
    AgentRole.SCHEDULE_PLANNER: "schedule_planner",
    AgentRole.EXECUTOR: "executor",
    AgentRole.LIBRARIAN: "librarian",
    AgentRole.ANALYST: "analyst",
}

# Base system prompt for each specialist role. The master role is not listed here.
ROLE_SYSTEM_PROMPTS = {
    AgentRole.SCHEDULE_PLANNER: SCHEDULE_PLANNER_SYSTEM_PROMPT,
    AgentRole.EXECUTOR: EXECUTOR_SYSTEM_PROMPT,
    AgentRole.LIBRARIAN: LIBRARIAN_SYSTEM_PROMPT,
    AgentRole.ANALYST: ANALYST_SYSTEM_PROMPT,
}
|
||
|
||
# Substrings that mark a request as schedule-related (reminders, plans, meetings,
# relative dates, deadlines). Matched with plain `in` against the lower-cased query
# in _classify_request_signals, so short entries such as "周" ("week") match any
# text containing that character.
SCHEDULE_KEYWORDS = (
    "提醒",
    "日程",
    "安排",
    "计划",
    "排期",
    "会议",
    "开会",
    "明天",
    "今天",
    "后天",
    "下周",
    "本周",
    "周",
    "星期",
    "交付",
    "节点",
    "deadline",
)
# Bookkeeping / spending phrases. When any matches, _classify_request_signals
# suppresses the schedule signal and treats the request as an execution request.
ACCOUNTING_INTENT_KEYWORDS = (
    "记账",
    "账单",
    "花了多少钱",
    "用了多少钱",
    "支出",
    "消费",
    "花销",
    "开销",
)
# Knowledge / search phrases that select the librarian role.
KNOWLEDGE_KEYWORDS = ("知识", "搜索", "检索", "资料", "文档", "联网", "上网", "查询", "查一下", "最新")
# General Q&A phrasing ("introduce", "what is", "why", ...). In
# _route_agent_from_user_query these keep the master role in charge, but only
# when no accounting/schedule/analysis/knowledge signal fired first.
GENERAL_QA_PATTERNS = (
    "介绍一下",
    "介绍下",
    "什么是",
    "是谁",
    "在哪里",
    "为什么",
    "怎么理解",
    "聊聊",
)
# Analysis / report phrases that select the analyst role.
ANALYSIS_KEYWORDS = ("分析", "报告", "统计", "趋势", "风险", "洞察", "总结")
# Action verbs (create, update, execute, post, ...) that select the executor role.
EXECUTION_KEYWORDS = ("创建", "更新", "修改", "执行", "发帖", "论坛", "帖子", "完成", "处理")
# Schedule-analysis phrasing (focus, judgement, priority, trade-offs).
# NOTE(review): not referenced in the visible part of this module; confirm it is still used.
SCHEDULE_ANALYSIS_KEYWORDS = ("聚焦", "判断", "分析", "优先级", "取舍", "最近对话", "该做什么")
# Phrases that send schedule requests to the "schedule_planning" sub-commander
# (see _choose_sub_commander); otherwise "schedule_analysis" is used.
SCHEDULE_PLANNING_KEYWORDS = ("安排", "计划", "排期", "提醒", "创建", "新增", "会议", "交付", "节点")
# "Who are you" (simplified and traditional). Prefix-matched against normalized user text.
IDENTITY_PATTERNS = ("你是谁", "你是誰")
# "What can you do" variants. Prefix-matched against normalized user text.
CAPABILITY_PATTERNS = ("你能做什么", "你可以做什么", "你会做什么")
# Bare confirmations ("create", "ok create", ...). Compared for exact equality
# with the output of _normalize_user_text.
SHORT_CONFIRMATION_PATTERNS = ("创建", "好的创建", "确认创建", "就创建", "那就创建")
|
||
# Phrases an assistant reply uses when it offers to create a reminder
# ("create this reminder", "shall I create it now", ...).
SCHEDULE_CONFIRMATION_HINTS = (
    "创建这条提醒",
    "现在创建这条提醒",
    "现在创建提醒",
    "是否需要我现在创建",
)

# Markers that make such a reply read as a question ("whether", "do you want",
# the particle "吗", and question marks). Both the ASCII "?" and the full-width
# "?" (U+FF1F) are accepted. The full-width one is written as an escape so it
# cannot be silently folded into a duplicate ASCII "?" by editors or Unicode
# normalization, which previously left full-width questions unrecognized.
SCHEDULE_CONFIRMATION_QUESTION_MARKERS = ("是否", "要不要", "吗", "?", "\uff1f")
|
||
# Connectives that suggest a multi-step request ("then", "again/also", "and",
# "meanwhile", "by the way", "finally", "separately", "split into", "collaborate",
# "integrate"). Used by _select_request_mode as one of the triggers for
# collaboration mode. Single characters such as "再" match very broadly.
COLLABORATION_STEP_MARKERS = ("然后", "再", "并且", "同时", "顺便", "最后", "分别", "拆成", "协作", "整合")

# Relative ordering of specialist roles: librarian, analyst, schedule planner, executor.
# NOTE(review): not referenced in the visible part of this module; presumably used to
# sort collaborators so evidence-gathering runs before execution. Confirm at the use site.
COLLABORATION_ROLE_ORDER = {
    AgentRole.LIBRARIAN: 0,
    AgentRole.ANALYST: 1,
    AgentRole.SCHEDULE_PLANNER: 2,
    AgentRole.EXECUTOR: 3,
}
|
||
def _role_value(role: AgentRole | str | None) -> str:
    """Return the plain string form of *role*.

    ``AgentRole`` members yield their ``.value``. Anything else is passed through
    ``str()``, with falsy inputs (``None``, ``""``) mapped to ``""``.
    """
    return role.value if isinstance(role, AgentRole) else str(role or "")
|
||
|
||
|
||
def _normalize_current_agent(value: AgentRole | str | None) -> str:
    """Return the role string for *value*, falling back to the master role when empty."""
    resolved = _role_value(value)
    if not resolved:
        return AgentRole.MASTER.value
    return resolved
|
||
|
||
|
||
def _normalize_active_agents(values: list[AgentRole | str] | None) -> list[AgentRole]:
    """Coerce *values* into an ordered, de-duplicated list of ``AgentRole`` members.

    Entries whose string form is not a valid ``AgentRole`` value are dropped.
    A missing or empty input, or one where nothing survives, yields
    ``[AgentRole.MASTER]``.
    """
    candidates = values if values else [AgentRole.MASTER]
    accepted: list[AgentRole] = []
    for candidate in candidates:
        role_text = _role_value(candidate)
        try:
            parsed = AgentRole(role_text)
        except ValueError:
            continue
        if parsed in accepted:
            continue
        accepted.append(parsed)
    return accepted if accepted else [AgentRole.MASTER]
|
||
|
||
|
||
def _normalize_user_text(text: str) -> str:
    """Canonicalize user text for exact/prefix phrase matching.

    The text is trimmed and lower-cased. Then every run of whitespace and of
    sentence punctuation is removed, in both ASCII form (``, . ! ? ; :``) and
    CJK full-width form (,。!?;:). ``None`` or empty input yields ``""``.

    The full-width characters are written as ``\\u`` escapes. The previous
    literal class listed the ASCII characters twice and no full-width comma,
    exclamation, question mark, semicolon or colon. That meant inputs such as
    "好的,创建" or "你好!" never matched the confirmation/greeting tables.
    """
    normalized = (text or "").strip().lower()
    # ASCII , . ! ? ; :   full-width , 。 ! ? ; :   and any whitespace
    return re.sub(r"[,.!?;:\uff0c\u3002\uff01\uff1f\uff1b\uff1a\s]+", "", normalized)
|
||
|
||
|
||
def _stringify_message_content(content: Any) -> str:
    """Flatten message content (string, content-part list, or dict) into plain text.

    Rules:
    - ``None`` gives ``""``; a ``str`` is returned unchanged.
    - ``list``: elements are rendered and concatenated with no separator. A string
      element is used as-is. A dict element contributes its truthy ``"text"``,
      otherwise its truthy ``"content"`` (rendered recursively), otherwise
      ``str(element)``. Any other element contributes ``str(element)``.
    - ``dict``: a truthy ``"text"`` wins, then a truthy ``"content"`` (recursive);
      otherwise the dict is serialized with ``json.dumps(..., ensure_ascii=False)``.
    - Anything else falls back to ``str(content)``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content

    if isinstance(content, list):

        def render_part(part: Any) -> str:
            if isinstance(part, str):
                return part
            if isinstance(part, dict):
                part_text = part.get("text")
                if part_text:
                    return str(part_text)
                inner = part.get("content")
                if inner:
                    return _stringify_message_content(inner)
            # Non-dict parts, and dicts with neither "text" nor "content", use their str() form.
            return str(part)

        return "".join(render_part(part) for part in content)

    if isinstance(content, dict):
        dict_text = content.get("text")
        if dict_text:
            return str(dict_text)
        inner = content.get("content")
        if inner:
            return _stringify_message_content(inner)
        return json.dumps(content, ensure_ascii=False)

    return str(content)
|
||
|
||
|
||
def _get_state_int(state: AgentState, key: str) -> int:
    """Return ``state[key]`` if it is an ``int`` (``bool`` included), otherwise ``0``."""
    candidate = state.get(key)
    if isinstance(candidate, int):
        return candidate
    return 0
|
||
|
||
|
||
def _role_values() -> set[str]:
    """Return the set of string values of every ``AgentRole`` member."""
    values: set[str] = set()
    for member in AgentRole:
        values.add(member.value)
    return values
|
||
|
||
|
||
def _summary_state_key(target: str) -> Literal["schedule_context_summary", "knowledge_context", "analysis_report"]:
    """Validate *target* as one of the three summary slots and return it, narrowed.

    Raises:
        ValueError: if *target* is not ``schedule_context_summary``,
            ``knowledge_context`` or ``analysis_report``.
    """
    known_keys = {"schedule_context_summary", "knowledge_context", "analysis_report"}
    if target in known_keys:
        return cast(Literal["schedule_context_summary", "knowledge_context", "analysis_report"], target)
    raise ValueError(f"unsupported summary target: {target}")
|
||
|
||
|
||
def _ensure_message_thread(state: AgentState) -> str:
|
||
thread_id = str(state.get("thread_id") or "").strip()
|
||
if thread_id:
|
||
return thread_id
|
||
thread_id = f"thread-{uuid4()}"
|
||
state["thread_id"] = thread_id
|
||
return thread_id
|
||
|
||
|
||
def _bump_message_sequence(state: AgentState) -> tuple[str, int]:
|
||
sequence = int(state.get("message_sequence") or 0) + 1
|
||
state["message_sequence"] = sequence
|
||
message_id = f"msg-{sequence}"
|
||
state["last_message_id"] = message_id
|
||
return message_id, sequence
|
||
|
||
|
||
def _task_payload(task: AgentTask | dict[str, Any] | None) -> dict[str, Any]:
    """Return a plain dict view of *task*.

    ``AgentTask`` models are dumped with ``model_dump(mode="json")``. Mappings are
    shallow-copied, and ``None`` (or any falsy value) becomes ``{}``.
    """
    if not isinstance(task, AgentTask):
        return dict(task or {})
    return task.model_dump(mode="json")
|
||
|
||
|
||
def _budget_snapshot_from_state(
    state: AgentState,
    *,
    mode: Literal["direct", "collaboration"],
    remaining_parallel_tasks: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build a JSON-serialized ``CollaborationBudget`` from the counters in *state*.

    - Tool-call and iteration ceilings come from ``max_tool_rounds`` /
      ``max_iterations``. A value of 0 (or missing) means "no limit" and is
      reported as ``None``, and then the matching ``remaining_*`` is ``None`` too.
      Otherwise the remaining count is the ceiling minus usage, floored at 0.
    - ``metadata`` starts from fixed defaults (spawn depth 1, 24 messages per
      thread, 8 per turn). Both child-agent and parallel-collaborator limits are
      set to *remaining_parallel_tasks*, ``recovery_attempt_limit`` is
      ``max_retries`` or 1, and ``current_depth`` is ``collaboration_depth``.
      Caller-supplied *metadata* keys override these defaults.
    """
    tool_round_limit = _get_state_int(state, "max_tool_rounds")
    tool_rounds_used = _get_state_int(state, "tool_round_count")
    iteration_limit = _get_state_int(state, "max_iterations")
    iterations_used = _get_state_int(state, "iteration_count")

    budget_metadata: dict[str, Any] = {
        "max_spawn_depth": 1,
        "max_child_agents": remaining_parallel_tasks,
        "max_messages_per_thread": 24,
        "max_messages_per_turn": 8,
        "max_parallel_collaborators": remaining_parallel_tasks,
        "recovery_attempt_limit": _get_state_int(state, "max_retries") or 1,
        "current_depth": _get_state_int(state, "collaboration_depth"),
    }
    budget_metadata.update(metadata or {})

    snapshot = CollaborationBudget(
        mode=mode,
        max_parallel_tasks=remaining_parallel_tasks,
        remaining_parallel_tasks=remaining_parallel_tasks,
        max_tool_calls=tool_round_limit or None,
        remaining_tool_calls=max(tool_round_limit - tool_rounds_used, 0) if tool_round_limit else None,
        max_iterations=iteration_limit or None,
        remaining_iterations=max(iteration_limit - iterations_used, 0) if iteration_limit else None,
        metadata=budget_metadata,
    )
    return snapshot.model_dump(mode="json")
|
||
|
||
|
||
def _store_budget_snapshot(state: AgentState, budget_snapshot: dict[str, Any]) -> None:
    """Make *budget_snapshot* the current budget and append it to the budget history.

    The history is rebuilt as a new list, so a list object previously held by
    the state is not mutated.
    """
    history = list(state.get("collaboration_budget_history") or [])
    history.append(budget_snapshot)
    state["budget_state"] = budget_snapshot
    state["collaboration_budget_history"] = history
|
||
|
||
|
||
def _budget_metadata(state: AgentState) -> dict[str, Any]:
    """Return a copy of the ``metadata`` mapping of ``state["budget_state"]``.

    Handles both a live ``CollaborationBudget`` model and its serialized dict
    form. A missing budget or any other type yields ``{}``.
    """
    budget = state.get("budget_state") or {}
    if isinstance(budget, CollaborationBudget):
        raw_metadata = budget.metadata
    elif isinstance(budget, dict):
        raw_metadata = budget.get("metadata")
    else:
        return {}
    return dict(raw_metadata or {})
|
||
|
||
|
||
def _budget_limit(state: AgentState, key: str, default: int) -> int:
    """Return the positive integer limit stored under *key* in the budget metadata.

    Falls back to *default* when the key is absent, not an ``int``, or ``<= 0``.
    """
    limit = _budget_metadata(state).get(key)
    if not isinstance(limit, int) or limit <= 0:
        return default
    return limit
|
||
|
||
|
||
def _append_message_trace(
    state: AgentState,
    *,
    from_agent_id: str,
    to_agent_id: str,
    message_type: str,
    content_summary: str,
    task_id: str | None = None,
    reply_to_message_id: str | None = None,
    payload: dict[str, Any] | None = None,
    message_id: str | None = None,
    message_index: int | None = None,
) -> dict[str, Any]:
    """Record one inter-agent message on ``state["message_trace"]`` and return it serialized.

    The thread id is created on demand (see ``_ensure_message_thread``). Unless the
    caller supplies BOTH *message_id* and *message_index*, a fresh id is taken
    from ``_bump_message_sequence``. The summary is truncated to 400 characters.

    NOTE(review): *message_index* is only used to decide whether to bump the
    sequence; it is not passed to ``AgentMessage``. Confirm this is intended.
    """
    thread_id = _ensure_message_thread(state)
    caller_supplied_id = message_id is not None and message_index is not None
    if not caller_supplied_id:
        message_id, message_index = _bump_message_sequence(state)

    serialized = AgentMessage(
        message_id=message_id,
        thread_id=thread_id,
        from_agent_id=from_agent_id,
        to_agent_id=to_agent_id,
        task_id=task_id,
        reply_to_message_id=reply_to_message_id,
        message_type=cast(Any, message_type),
        content_summary=content_summary[:400],
        payload=payload or {},
    ).model_dump(mode="json")

    # Rebuild the list instead of mutating it in place.
    trace = list(state.get("message_trace") or [])
    trace.append(serialized)
    state["message_trace"] = trace
    return serialized
|
||
|
||
|
||
def _create_child_agent(
    state: AgentState,
    *,
    role: AgentRole,
    task: AgentTask,
) -> str | None:
    """Try to spawn a child agent of *role* to work on *task*.

    Guards, checked in order:
    1. the current agent's registry manifest allows spawning *role*;
    2. ``collaboration_depth`` is below the ``max_spawn_depth`` budget (default 1);
    3. fewer than ``max_child_agents`` (default 4) agents were already spawned.

    If any guard fails, an ``agent.spawn.blocked`` warning event is recorded and
    ``None`` is returned. On success the new id (``<role>-<8 hex chars>``) is
    appended to ``state["spawned_agent_ids"]``, an ``agent.created`` event is
    recorded, and the id is returned.
    """
    parent_agent_id = str(state.get("agent_id") or AgentRole.MASTER.value)

    def _report_blocked(details: dict[str, Any]) -> None:
        # Every refusal emits the same event shape; only the payload differs.
        _append_event_trace(
            state,
            "agent.spawn.blocked",
            payload=details,
            severity="warning",
            task_id=task.task_id,
            parent_task_id=task.parent_task_id,
            child_task_id=task.task_id,
        )

    if not _spawn_permission_for_role(state, role):
        _report_blocked({"reason": "role_policy_blocked", "target_role": role.value})
        return None

    depth = _get_state_int(state, "collaboration_depth")
    if depth >= _budget_limit(state, "max_spawn_depth", 1):
        _report_blocked({"reason": "max_spawn_depth_exceeded", "target_role": role.value, "depth": depth})
        return None

    existing_children = list(state.get("spawned_agent_ids") or [])
    if len(existing_children) >= _budget_limit(state, "max_child_agents", 4):
        _report_blocked({"reason": "max_child_agents_exceeded", "target_role": role.value})
        return None

    new_agent_id = f"{role.value}-{uuid4().hex[:8]}"
    state["spawned_agent_ids"] = existing_children + [new_agent_id]
    _append_event_trace(
        state,
        "agent.created",
        payload={"parent_agent_id": parent_agent_id, "child_agent_id": new_agent_id, "target_role": role.value},
        task_id=task.task_id,
        parent_task_id=task.parent_task_id,
        child_task_id=task.task_id,
    )
    return new_agent_id
|
||
|
||
|
||
def _collaboration_task_from_state(state: AgentState) -> dict[str, Any]:
    """Return a shallow copy of ``turn_context["collaboration_task"]``, or ``{}`` when absent."""
    turn_context = state.get("turn_context") or {}
    collaboration_task = turn_context.get("collaboration_task") or {}
    return dict(collaboration_task)
|
||
|
||
|
||
def _record_interrupt(
    state: AgentState,
    *,
    reason: str,
    task: AgentTask | dict[str, Any] | None = None,
    payload: dict[str, Any] | None = None,
) -> InterruptRecord:
    """Record an interrupt for *task* (or, if ``None``, the turn's collaboration task).

    Side effects on *state*:
    - appends the serialized ``InterruptRecord`` to ``interrupted_tasks``;
    - appends a recovery point (interrupt id, task id, thread id, last message id)
      to ``recovery_points``;
    - emits three warning events in order: ``agent.interrupt.requested``,
      ``agent.task.interrupted``, ``agent.interrupt.completed``.
    """
    task_payload = _task_payload(task) if task is not None else _collaboration_task_from_state(state)
    interrupt = InterruptRecord(
        interrupt_id=f"interrupt-{uuid4()}",
        reason=reason,
        requested_by=_role_value(state.get("current_agent")),
        payload=payload or {},
    )

    state["interrupted_tasks"] = [*(state.get("interrupted_tasks") or []), interrupt.model_dump(mode="json")]
    recovery_point = {
        "interrupt_id": interrupt.interrupt_id,
        "task_id": task_payload.get("task_id"),
        "thread_id": state.get("thread_id"),
        "message_id": state.get("last_message_id"),
    }
    state["recovery_points"] = [*(state.get("recovery_points") or []), recovery_point]

    # Only the first child task (if any) is attached to the events.
    first_child_task_id = (task_payload.get("child_task_ids") or [None])[0]
    for event_name in ("agent.interrupt.requested", "agent.task.interrupted", "agent.interrupt.completed"):
        _append_event_trace(
            state,
            event_name,
            # A fresh payload dict per event, so events never share a mutable object.
            payload={"reason": reason, **(payload or {})},
            severity="warning",
            task_id=task_payload.get("task_id"),
            parent_task_id=task_payload.get("parent_task_id"),
            child_task_id=first_child_task_id,
            interrupt_id=interrupt.interrupt_id,
        )
    return interrupt
|
||
|
||
|
||
def _record_recovery(
    state: AgentState,
    *,
    interrupt: InterruptRecord,
    strategy: str,
    task: AgentTask | dict[str, Any] | None = None,
    payload: dict[str, Any] | None = None,
) -> RecoveryRecord:
    """Record recovery from *interrupt* using *strategy*.

    Appends the serialized ``RecoveryRecord`` to ``state["recovery_trace"]``. It
    then emits ``agent.recovery.started``, ``agent.task.recovered`` and
    ``agent.recovery.completed`` in that order. The task defaults to the turn's
    collaboration task when *task* is ``None``.
    """
    task_payload = _task_payload(task) if task is not None else _collaboration_task_from_state(state)
    recovery = RecoveryRecord(
        recovery_id=f"recovery-{uuid4()}",
        source_interrupt_id=interrupt.interrupt_id,
        strategy=strategy,
        resumed_from_task_id=task_payload.get("task_id"),
        resumed_from_thread_id=str(state.get("thread_id") or "") or None,
        payload=payload or {},
    )
    state["recovery_trace"] = [*(state.get("recovery_trace") or []), recovery.model_dump(mode="json")]

    first_child_task_id = (task_payload.get("child_task_ids") or [None])[0]
    for event_name in ("agent.recovery.started", "agent.task.recovered", "agent.recovery.completed"):
        _append_event_trace(
            state,
            event_name,
            payload={"strategy": strategy, **(payload or {})},
            task_id=task_payload.get("task_id"),
            parent_task_id=task_payload.get("parent_task_id"),
            child_task_id=first_child_task_id,
            recovery_id=recovery.recovery_id,
            interrupt_id=interrupt.interrupt_id,
        )
    return recovery
|
||
|
||
|
||
def _spawn_permission_for_role(state: AgentState, role: AgentRole) -> bool:
    """Return True if the current agent's builtin manifest may spawn a *role* child.

    Requires a manifest registered for the current role (master by default) with
    ``can_spawn_children`` set, and *role* listed among its spawnable roles.
    """
    registry = load_builtin_registry_indexes()
    current_role = _normalize_current_agent(state.get("current_agent"))
    manifest = registry.agent_by_role_value.get(current_role)
    if manifest is not None and manifest.can_spawn_children:
        spawnable_roles = registry.spawnable_role_values_by_agent_id.get(manifest.agent_id, ())
        return role.value in spawnable_roles
    return False
|
||
|
||
|
||
def _get_llm_for_state(state: AgentState):
    """Create the chat model for this turn and record its provider capabilities.

    Uses ``state["user_llm_config"]`` when set, otherwise the default model from
    ``get_llm()``. Capability detection and the ``state["provider_capabilities"]``
    write are delegated to ``_resolve_capabilities``, which runs exactly the same
    logic this function used to inline.
    """
    user_config = state.get("user_llm_config")
    model = create_llm_from_config(user_config) if user_config else get_llm()
    _resolve_capabilities(state, model)
    return model
|
||
|
||
|
||
def _resolve_capabilities(state: AgentState, llm) -> Any:
    """Determine provider capabilities for *llm* and mirror them into the state.

    Prefers the ``_jarvis_provider_capabilities`` attribute attached to the model.
    Otherwise it resolves from ``state["user_llm_config"]``, falling back to the
    defaults. Writes ``provider``, ``supports_native_tools`` and
    ``preferred_tool_strategy`` to ``state["provider_capabilities"]`` and returns
    the capabilities object.
    """
    capabilities = getattr(llm, "_jarvis_provider_capabilities", None)
    if capabilities is None:
        user_config = state.get("user_llm_config")
        if user_config:
            capabilities = resolve_provider_capabilities(user_config)
        else:
            capabilities = default_provider_capabilities()
    state["provider_capabilities"] = {
        field_name: getattr(capabilities, field_name)
        for field_name in ("provider", "supports_native_tools", "preferred_tool_strategy")
    }
    return capabilities
|
||
|
||
|
||
def _filter_user_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Return only the messages whose ``type`` is ``"human"`` or ``"user"``, in order."""
    user_types = {"human", "user"}
    return list(filter(lambda msg: getattr(msg, "type", "") in user_types, messages))
|
||
|
||
|
||
def _coalesce_system_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Merge all system messages into one leading ``SystemMessage``.

    Non-empty system texts are joined with blank lines and placed first. The other
    messages keep their relative order, and system messages with empty text are
    dropped. If no system text remains, only the non-system messages are returned.
    """
    system_texts: list[str] = []
    conversation: list[BaseMessage] = []
    for msg in messages:
        if getattr(msg, "type", "") != "system":
            conversation.append(msg)
            continue
        text = _stringify_message_content(getattr(msg, "content", ""))
        if text:
            system_texts.append(text)

    if system_texts:
        return [SystemMessage(content="\n\n".join(system_texts)), *conversation]
    return conversation
|
||
|
||
|
||
def _is_simple_greeting(text: str) -> bool:
    """Return True if *text* normalizes to a bare greeting (e.g. "你好", "hi", "hello")."""
    greetings = frozenset(("你好", "您好", "早", "早上好", "在吗", "嗨", "hi", "hello"))
    return _normalize_user_text(text) in greetings
|
||
|
||
|
||
def _is_identity_question(text: str) -> bool:
    """Return True if normalized *text* starts with a "who are you" phrase."""
    # str.startswith accepts a tuple of prefixes and tests them all.
    return _normalize_user_text(text).startswith(IDENTITY_PATTERNS)
|
||
|
||
|
||
def _is_capability_question(text: str) -> bool:
    """Return True if normalized *text* starts with a "what can you do" phrase."""
    return _normalize_user_text(text).startswith(CAPABILITY_PATTERNS)
|
||
|
||
|
||
def _is_short_confirmation(text: str) -> bool:
    """Return True if *text*, once normalized, is exactly a bare "create" confirmation."""
    candidate = _normalize_user_text(text)
    return any(candidate == pattern for pattern in SHORT_CONFIRMATION_PATTERNS)
|
||
|
||
|
||
def _tool_result_indicates_failure(result: Any) -> bool:
    """Heuristically detect a failed tool call.

    Returns True when the flattened result text contains "失败" ("failed").
    This is a substring check, not a structured status.
    """
    return "失败" in _stringify_message_content(result)
|
||
|
||
|
||
def _latest_assistant_message_content(messages: list[BaseMessage]) -> str:
    """Return the text of the newest ``ai`` message before the last message.

    The final message (the current turn) is skipped. Returns ``""`` when no
    earlier assistant message exists.
    """
    for candidate in reversed(messages[:-1]):
        if getattr(candidate, "type", "") == "ai":
            return _stringify_message_content(getattr(candidate, "content", ""))
    return ""
|
||
|
||
|
||
def _previous_turn_proposed_schedule_creation(messages: list[BaseMessage]) -> bool:
    """Return True if the previous assistant reply offered to create a reminder.

    Needs at least two messages. The reply must contain one of the reminder
    hints and a question marker. That check is exactly
    ``_is_schedule_creation_confirmation_response``, so it is reused here.
    """
    if len(messages) < 2:
        return False
    return _is_schedule_creation_confirmation_response(_latest_assistant_message_content(messages))
|
||
|
||
|
||
def _previous_turn_completed_reminder_creation(state: AgentState) -> bool:
    """Return True when the last turn demonstrably finished creating a reminder.

    All of these must hold:
    - the newest created entity has ``type == "reminder"``;
    - the newest tool call was ``create_reminder``;
    - the conversation ends with a user message directly preceded by an ``ai``
      message whose text contains a "reminder created" phrase.
    """
    entities = state.get("created_entities") or []
    if not entities or entities[-1].get("type") != "reminder":
        return False

    calls = state.get("tool_calls") or []
    if not calls or calls[-1].get("name") != "create_reminder":
        return False

    history = state.get("messages") or []
    if len(history) < 2:
        return False
    newest, preceding = history[-1], history[-2]
    if getattr(newest, "type", "") not in {"human", "user"}:
        return False
    if getattr(preceding, "type", "") != "ai":
        return False

    reply_text = _stringify_message_content(getattr(preceding, "content", ""))
    done_phrases = ("已创建提醒", "提醒已经创建好了", "帮你设好了这条提醒", "创建成功")
    return any(phrase in reply_text for phrase in done_phrases)
|
||
|
||
|
||
def _latest_non_confirmation_user_request(messages: list[BaseMessage]) -> str | None:
    """Return the most recent earlier user request that is not a bare confirmation.

    The newest user message (the current turn) is excluded. Earlier user messages
    are scanned newest-first, and the first non-empty one that is not a short
    confirmation is returned, stripped. Returns ``None`` if none qualifies.
    """
    user_texts = [
        _stringify_message_content(getattr(msg, "content", "")).strip()
        for msg in messages
        if getattr(msg, "type", "") in {"human", "user"}
    ]
    earlier_texts = user_texts[:-1]
    return next(
        (candidate for candidate in reversed(earlier_texts) if candidate and not _is_short_confirmation(candidate)),
        None,
    )
|
||
|
||
|
||
def _expand_schedule_confirmation_query(user_query: str, messages: list[BaseMessage]) -> str:
    """Turn a bare "create" confirmation back into the request it confirms.

    If an earlier substantive user request exists, a query is returned that
    restates it ("user confirms continuing to create the previous reminder: ...").
    Otherwise *user_query* is returned unchanged.
    """
    earlier_request = _latest_non_confirmation_user_request(messages)
    if earlier_request:
        return f"用户确认继续创建上一条提醒安排:{earlier_request}"
    return user_query
|
||
|
||
|
||
def _is_schedule_creation_confirmation_response(response_text: str) -> bool:
    """Return True if *response_text* asks the user whether to create a reminder.

    Requires both a reminder-creation hint and a question marker somewhere in the text.
    """
    text = _stringify_message_content(response_text)
    if not any(hint in text for hint in SCHEDULE_CONFIRMATION_HINTS):
        return False
    return any(marker in text for marker in SCHEDULE_CONFIRMATION_QUESTION_MARKERS)
|
||
|
||
|
||
def _write_schedule_creation_continuity(state: AgentState, user_query: str) -> None:
    """Remember a pending reminder-creation so a later "create" reply can continue it.

    Does nothing when *user_query* is blank. Otherwise it writes a pending
    ``schedule_creation`` action summarizing the query, a routing decision to the
    schedule planner, and a ``fresh`` continuity marker.
    """
    summary = user_query.strip()
    if summary:
        state["pending_action"] = {"type": "schedule_creation", "summary": summary, "status": "pending"}
        state["routing_decision"] = {
            "target_agent": AgentRole.SCHEDULE_PLANNER.value,
            "reason": "continue_pending_action",
        }
        state["continuity_state"] = {"status": "fresh"}
|
||
|
||
|
||
def _clear_structured_continuity(state: AgentState) -> None:
    """Reset the pending action, routing decision and continuity marker to ``None``."""
    for key in ("pending_action", "routing_decision", "continuity_state"):
        state[key] = None
|
||
|
||
|
||
def _should_clear_schedule_creation_continuity(state: AgentState, created_entities: list[dict[str, Any]]) -> bool:
    """Return True once a pending schedule-creation continuation has produced a reminder.

    The continuation must be active, its pending action must be of type
    ``schedule_creation``, and *created_entities* must contain a ``reminder``.
    """
    awaiting_schedule_creation = (
        _has_active_structured_continuation(state)
        and (state.get("pending_action") or {}).get("type") == "schedule_creation"
    )
    return awaiting_schedule_creation and any(entity.get("type") == "reminder" for entity in created_entities)
|
||
|
||
|
||
def _classify_request_signals(user_query: str) -> dict[str, Any]:
    """Detect which specialist roles a request touches, using keyword substrings.

    Returns the lower-cased, stripped ``text``, one boolean flag per signal, and
    ``roles`` in the fixed order librarian, analyst, schedule planner, executor.
    An accounting signal suppresses the schedule signal and implies execution.
    A ``<m>月<d>日`` date also counts as a schedule signal.
    """
    text = (user_query or "").strip().lower()

    def mentions(keywords: tuple[str, ...]) -> bool:
        return any(keyword in text for keyword in keywords)

    accounting = mentions(ACCOUNTING_INTENT_KEYWORDS)
    schedule = not accounting and (
        re.search(r"\d{1,2}月\d{1,2}日", text) is not None or mentions(SCHEDULE_KEYWORDS)
    )
    analysis = mentions(ANALYSIS_KEYWORDS)
    knowledge = mentions(KNOWLEDGE_KEYWORDS)
    execution = accounting or mentions(EXECUTION_KEYWORDS)

    role_flags = (
        (AgentRole.LIBRARIAN, knowledge),
        (AgentRole.ANALYST, analysis),
        (AgentRole.SCHEDULE_PLANNER, schedule),
        (AgentRole.EXECUTOR, execution),
    )
    roles: list[AgentRole] = [role for role, flagged in role_flags if flagged]

    return {
        "text": text,
        "has_accounting_signal": accounting,
        "has_schedule_signal": schedule,
        "has_analysis_signal": analysis,
        "has_knowledge_signal": knowledge,
        "has_execution_signal": execution,
        "roles": roles,
    }
|
||
|
||
|
||
def _select_request_mode(user_query: str) -> tuple[Literal["direct", "collaboration"], dict[str, Any]]:
    """Choose between single-agent ("direct") and multi-agent ("collaboration") handling.

    Collaboration is selected when three or more roles are signalled, or when
    exactly two are and the request also looks multi-step, explicitly asks for
    collaboration or task splitting, or is at least 24 characters long.
    Returns the mode plus metadata (mode, reason, role values, multi-step flag).
    """
    signals = _classify_request_signals(user_query)
    matched_roles: list[AgentRole] = signals["roles"]
    lowered = signals["text"]

    multi_step = any(marker in lowered for marker in COLLABORATION_STEP_MARKERS)
    explicitly_requested = "协作" in lowered or ("拆" in lowered and "任务" in lowered)
    lengthy = len((user_query or "").strip()) >= 24

    role_count = len(matched_roles)
    if role_count >= 3:
        collaborate = True
    elif role_count >= 2:
        collaborate = multi_step or explicitly_requested or lengthy
    else:
        collaborate = False

    selected_mode: Literal["direct", "collaboration"] = "collaboration" if collaborate else "direct"
    reason = "multi_role_request" if collaborate else "single_role_or_simple_request"
    return selected_mode, {
        "mode": selected_mode,
        "reason": reason,
        "roles": [matched.value for matched in matched_roles],
        "multi_step": multi_step,
    }
|
||
|
||
|
||
def _route_agent_from_user_query(user_query: str) -> AgentRole:
    """Pick the single role that should answer *user_query*.

    Priority: accounting goes to executor, then schedule to planner, analysis to
    analyst, knowledge to librarian. Next, general Q&A phrasing goes to master,
    then execution verbs to executor, and master otherwise.
    """
    signals = _classify_request_signals(user_query)
    if signals["has_accounting_signal"]:
        return AgentRole.EXECUTOR

    prioritized = (
        ("has_schedule_signal", AgentRole.SCHEDULE_PLANNER),
        ("has_analysis_signal", AgentRole.ANALYST),
        ("has_knowledge_signal", AgentRole.LIBRARIAN),
    )
    for flag_name, role in prioritized:
        if signals[flag_name]:
            return role

    if any(pattern in signals["text"] for pattern in GENERAL_QA_PATTERNS):
        return AgentRole.MASTER
    return AgentRole.EXECUTOR if signals["has_execution_signal"] else AgentRole.MASTER
|
||
|
||
|
||
def _choose_sub_commander(role: AgentRole, user_query: str) -> str:
    """Return the sub-commander key (see ``SUB_COMMANDER_PROMPTS``) for *role* and *user_query*.

    Each specialist role has a keyword-triggered variant and a default variant.

    Raises:
        ValueError: for roles without sub-commanders (e.g. master).
    """
    text = (user_query or "").strip().lower()

    def mentions(keywords: tuple[str, ...]) -> bool:
        return any(keyword in text for keyword in keywords)

    if role == AgentRole.SCHEDULE_PLANNER:
        has_explicit_date = re.search(r"\d{1,2}月\d{1,2}日", text) is not None
        if has_explicit_date or mentions(SCHEDULE_PLANNING_KEYWORDS):
            return "schedule_planning"
        return "schedule_analysis"
    if role == AgentRole.EXECUTOR:
        return "executor_forum" if mentions(("论坛", "帖子", "发帖", "指令")) else "executor_tasks"
    if role == AgentRole.LIBRARIAN:
        return "librarian_graph" if mentions(("图谱", "关系", "沉淀", "graph")) else "librarian_retrieval"
    if role == AgentRole.ANALYST:
        return "analyst_insights" if mentions(("趋势", "风险", "洞察", "建议", "机会")) else "analyst_progress"
    raise ValueError(f"unsupported role: {role}")
|
||
|
||
|
||
def _is_missing_knowledge_result(tool_result: str | None) -> bool:
|
||
text = (tool_result or "").strip()
|
||
if not text:
|
||
return True
|
||
markers = (
|
||
"未找到相关知识",
|
||
"知识库可能为空",
|
||
"未找到相关网页结果",
|
||
"暂无相关记录",
|
||
"没有找到",
|
||
|
||
)
|
||
return any(marker in text for marker in markers)
|
||
|
||
|
||
def _extract_json_object(content: str) -> str | None:
|
||
text = (content or "").strip()
|
||
if not text:
|
||
return None
|
||
|
||
fenced_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
|
||
if fenced_match:
|
||
return fenced_match.group(1)
|
||
|
||
decoder = json.JSONDecoder()
|
||
for index, char in enumerate(text):
|
||
if char != "{":
|
||
continue
|
||
try:
|
||
_, end = decoder.raw_decode(text[index:])
|
||
return text[index:index + end]
|
||
except json.JSONDecodeError:
|
||
continue
|
||
return None
|
||
|
||
|
||
def _parse_json_action(content: str, allowed_tools: list[str]) -> dict[str, Any] | None:
    """Parse a JSON-mode action from a model reply (used when native tool calling is unavailable).

    Accepted shapes (anything else returns ``None``):
    - ``{"mode": "tool_call", "tool_calls": [{"name", "arguments" | "parameters", "reason"?}, ...]}``
      becomes ``{"mode": "tool_call", "tool_calls": [{"name", "args", "reason"}, ...]}``.
      Every entry must name a tool in *allowed_tools* and carry an argument object;
      a single bad entry rejects the whole action. Arguments may also be given as a
      JSON-encoded string (the OpenAI function-call convention) and are decoded.
    - ``{"mode": "final", "final_response": <str>}``
    - ``{"mode": "clarification", "clarification_question": <str>}``

    NOTE(review): ``"tool_calls": []`` still yields a tool_call action with no
    calls; confirm callers handle that.
    """
    json_text = _extract_json_object(_stringify_message_content(content))
    if not json_text:
        return None
    try:
        payload = json.loads(json_text)
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict):
        return None

    mode = payload.get("mode")
    if mode == "tool_call":
        raw_tool_calls = payload.get("tool_calls", [])
        if not isinstance(raw_tool_calls, list):
            return None
        tool_calls: list[dict[str, Any]] = []
        for item in raw_tool_calls:
            if not isinstance(item, dict):
                return None
            name = item.get("name")
            if name not in allowed_tools:
                return None
            args = item.get("arguments")
            if args is None:
                args = item.get("parameters")
            if isinstance(args, str):
                # Models often emit arguments as a JSON string; decode instead of rejecting.
                try:
                    args = json.loads(args)
                except json.JSONDecodeError:
                    return None
            if not isinstance(args, dict):
                return None
            tool_calls.append({"name": name, "args": args, "reason": item.get("reason")})
        return {"mode": mode, "tool_calls": tool_calls}

    if mode == "final" and isinstance(payload.get("final_response"), str):
        return {"mode": mode, "final_response": payload["final_response"]}

    if mode == "clarification" and isinstance(payload.get("clarification_question"), str):
        return {"mode": mode, "clarification_question": payload["clarification_question"]}

    return None
|
||
|
||
|
||
def _has_active_structured_continuation(state: AgentState) -> bool:
    """Return True when a fresh, pending continuation is routed to a known role.

    Requires ``continuity_state.status == "fresh"``, ``pending_action.status ==
    "pending"`` and ``routing_decision.reason == "continue_pending_action"``.
    The routing target must also be a valid ``AgentRole`` value.
    """
    pending = state.get("pending_action") or {}
    routing = state.get("routing_decision") or {}
    continuity = state.get("continuity_state") or {}

    markers_ok = (
        continuity.get("status") == "fresh"
        and pending.get("status") == "pending"
        and routing.get("reason") == "continue_pending_action"
    )
    if not markers_ok:
        return False
    return routing.get("target_agent") in _role_values()
|
||
|
||
|
||
def _route_from_structured_continuity(state: AgentState, user_query: str) -> AgentRole | None:
    """Return the role a bare confirmation should continue with, if any.

    Returns ``None`` unless *user_query* is a short confirmation and an active
    structured continuation exists (see ``_has_active_structured_continuation``).
    """
    if not (_is_short_confirmation(user_query) and _has_active_structured_continuation(state)):
        return None
    target = (state.get("routing_decision") or {}).get("target_agent")
    try:
        return AgentRole(str(target))
    except ValueError:
        return None
|
||
|
||
|
||
def _build_structured_continuity_summary(state: AgentState) -> str | None:
    """Render the active pending-action continuation as a system-prompt snippet.

    Returns None in three cases: there is no pending action, there is no routing
    decision, or the continuation is not active. The action type always appears
    (``unknown`` when missing). Summary, target agent and routing reason appear
    only when non-empty.
    """
    action = state.get("pending_action")
    decision = state.get("routing_decision")
    if not (action and decision and _has_active_structured_continuation(state)):
        return None

    summary_lines = [
        "structured_continuity:",
        f"- pending_action.type: {str(action.get('type') or 'unknown')}",
    ]
    optional_fields = (
        ("pending_action.summary", str(action.get("summary") or "")),
        ("routing_decision.target_agent", str(decision.get("target_agent") or "")),
        ("routing_decision.reason", str(decision.get("reason") or "")),
    )
    summary_lines.extend(f"- {label}: {value}" for label, value in optional_fields if value)
    summary_lines.append("- instruction: continue_pending_action")
    return "\n".join(summary_lines)
|
||
|
||
|
||
def _build_system_messages(state: AgentState, system_prompt: str, role: AgentRole, sub_commander: str) -> list[BaseMessage]:
    """Assemble the ordered system-message preamble for a sub-commander run.

    Messages are emitted in this order:
    1. *system_prompt*.
    2. Optional sections, each skipped when empty: the current datetime context,
       the structured-continuity summary, the clarification summary, the
       collaboration summary, the role-specific state context and the role's
       skill context.
    3. The sub-commander hand-off note and that sub-commander's dedicated prompt.
       Both are always present.
    """
    optional_sections = [
        state.get("current_datetime_context"),
        _build_structured_continuity_summary(state),
        _build_clarification_summary(state),
        _build_collaboration_context_summary(state),
    ]

    role_context = {
        AgentRole.SCHEDULE_PLANNER: state.get("schedule_context_summary"),
        AgentRole.LIBRARIAN: state.get("knowledge_context"),
        AgentRole.ANALYST: state.get("analysis_report"),
    }.get(role)
    if role_context:
        optional_sections.append(f"角色上下文:\n{role_context}")

    skill_key = ROLE_SKILL_CONTEXT.get(role)
    if skill_key:
        optional_sections.append(build_skill_context(skill_key))

    messages: list[BaseMessage] = [SystemMessage(content=system_prompt)]
    messages.extend(SystemMessage(content=section) for section in optional_sections if section)
    messages.append(SystemMessage(content=f"本次应由子指挥官 `{sub_commander}` 接手。"))
    messages.append(SystemMessage(content=SUB_COMMANDER_PROMPTS[sub_commander]))
    return messages
|
||
|
||
|
||
def _assign_agent_for_task_role(role_value: str | None) -> AgentRole:
    """Map a raw task-role string to an ``AgentRole``.

    Both empty/None input and unrecognised role strings resolve to
    ``AgentRole.MASTER``.
    """
    fallback = AgentRole.MASTER
    try:
        return AgentRole(str(role_value or fallback.value))
    except ValueError:
        return fallback
|
||
|
||
|
||
def _build_task_title(role: AgentRole) -> str:
    """Return the short title used for a collaboration task owned by *role*.

    Roles without a dedicated title fall back to a generic
    "handle collaboration task" label.
    """
    title_pairs = (
        (AgentRole.LIBRARIAN, "补齐事实与证据"),
        (AgentRole.ANALYST, "给出分析与判断"),
        (AgentRole.SCHEDULE_PLANNER, "形成安排与计划"),
        (AgentRole.EXECUTOR, "执行必要动作"),
    )
    return dict(title_pairs).get(role, "处理协作任务")
|
||
|
||
|
||
def _build_task_goal(role: AgentRole, user_query: str) -> str:
    """Return the goal statement for *role*'s collaboration task, embedding *user_query*.

    Roles without a dedicated template get the raw *user_query* as their goal.
    """
    templated_goals = {
        AgentRole.LIBRARIAN: f"围绕用户请求检索并整理能直接支撑结论的证据:{user_query}",
        AgentRole.ANALYST: f"基于已有上下文与证据,对用户请求给出结论、风险和建议:{user_query}",
        AgentRole.SCHEDULE_PLANNER: f"把用户请求收束为明确安排、节奏或日程建议:{user_query}",
        AgentRole.EXECUTOR: f"基于用户请求和前序结论执行必要动作,并回收结果:{user_query}",
    }
    try:
        return templated_goals[role]
    except KeyError:
        return user_query
|
||
|
||
|
||
def _build_expected_evidence(role: AgentRole) -> list[dict[str, Any]]:
    """Return the evidence a collaboration task for *role* is expected to produce.

    The result is a fresh single-item list of ``{"type", "detail"}`` dicts. Roles
    without a specific expectation get a generic ``summary`` entry.
    """
    evidence_type, detail = {
        AgentRole.LIBRARIAN: ("evidence", "retrieval findings or source-backed context"),
        AgentRole.ANALYST: ("analysis", "judgment with supporting rationale"),
        AgentRole.SCHEDULE_PLANNER: ("plan", "explicit schedule or next-step proposal"),
        AgentRole.EXECUTOR: ("execution", "tool output or execution confirmation"),
    }.get(role, ("summary", "task evidence"))
    return [{"type": evidence_type, "detail": detail}]
|
||
|
||
|
||
def _build_collaboration_tasks(user_query: str) -> list[AgentTask]:
    """Plan an ordered chain of collaboration tasks for *user_query*.

    Roles are taken from the ``roles`` metadata returned by ``_select_request_mode``
    and then processed as follows:
    - mapped to ``AgentRole`` values;
    - sorted by ``COLLABORATION_ROLE_ORDER``, with unlisted roles last;
    - de-duplicated, with MASTER dropped;
    - capped at four.

    All tasks share one synthetic root ``parent_task_id``. Each task lists the
    next task in the chain as its only child.
    """
    _, metadata = _select_request_mode(user_query)
    ordered_roles = sorted(
        map(_assign_agent_for_task_role, metadata.get("roles") or []),
        key=lambda candidate: COLLABORATION_ROLE_ORDER.get(candidate, 99),
    )

    selected: list[AgentRole] = []
    for candidate in ordered_roles:
        if candidate != AgentRole.MASTER and candidate not in selected:
            selected.append(candidate)
    selected = selected[:4]

    parent_task_id = f"task-root-{uuid4().hex[:8]}"
    task_ids = [f"task-{position}-{uuid4().hex[:8]}" for position in range(1, len(selected) + 1)]
    # Each task's child is its successor; the final task has none.
    successor_ids = [*task_ids[1:], None]

    return [
        AgentTask(
            task_id=task_id,
            title=_build_task_title(role),
            owner_agent_id=role.value,
            role=role.value,
            goal=_build_task_goal(role, user_query),
            parent_task_id=parent_task_id,
            child_task_ids=[successor_id] if successor_id is not None else [],
            expected_evidence=_build_expected_evidence(role),
        )
        for role, task_id, successor_id in zip(selected, task_ids, successor_ids)
    ]
|
||
|
||
|
||
def _build_collaboration_context_summary(state: AgentState) -> str | None:
    """Describe the ongoing collaboration for the system prompt.

    Returns None unless ``execution_mode`` is ``collaboration``. The text is built
    from three parts:
    - the coordinator prompt;
    - the current collaboration task from ``turn_context``, when present;
    - up to the last four task results, as ``task_id | status | summary`` rows.
    """
    if state.get("execution_mode") != "collaboration":
        return None

    out = [COORDINATOR_SYSTEM_PROMPT]

    active_task = (state.get("turn_context") or {}).get("collaboration_task") or {}
    if active_task:
        out.append("current_collaboration_task:")
        out.extend(f"- {field}: {active_task.get(field)}" for field in ("task_id", "title", "role", "goal"))

    recent_results = (state.get("task_results") or [])[-4:]
    if recent_results:
        out.append("completed_collaboration_results:")
        for result in recent_results:
            # Results may be TaskResult models or already-serialised dicts.
            record = result.model_dump(mode="json") if isinstance(result, TaskResult) else dict(result)
            out.append(f"- {record.get('task_id')} | {record.get('status')} | {record.get('summary') or ''}")

    return "\n".join(out)
|
||
|
||
|
||
def _maybe_reset_turn_budgets(state: AgentState) -> None:
|
||
messages = state.get("messages") or []
|
||
if not messages:
|
||
state["routing_hops"] = 0
|
||
state["terminated_due_to_loop_guard"] = False
|
||
state["iteration_count"] = 0
|
||
state["tool_round_count"] = 0
|
||
state["retry_count"] = 0
|
||
state["stop_reason"] = None
|
||
state["clarification_needed"] = False
|
||
state["clarification_question"] = None
|
||
state["final_response"] = None
|
||
return
|
||
|
||
last_message_type = getattr(messages[-1], "type", "")
|
||
has_prior_assistant_turn = any(getattr(message, "type", "") == "ai" for message in messages[:-1])
|
||
if last_message_type in {"human", "user"} and has_prior_assistant_turn:
|
||
state["routing_hops"] = 0
|
||
state["terminated_due_to_loop_guard"] = False
|
||
state["iteration_count"] = 0
|
||
state["tool_round_count"] = 0
|
||
state["retry_count"] = 0
|
||
state["stop_reason"] = None
|
||
state["clarification_needed"] = False
|
||
state["clarification_question"] = None
|
||
state["final_response"] = None
|
||
|
||
|
||
def _conversation_history_messages(state: AgentState) -> list[BaseMessage]:
|
||
history = list(state.get("messages", []))
|
||
return [message for message in history if getattr(message, "type", "") != "system"]
|
||
|
||
|
||
def _append_event_trace(
    state: AgentState,
    event_type: str,
    *,
    payload: dict[str, Any] | None = None,
    severity: str = "info",
    task_id: str | None = None,
    parent_task_id: str | None = None,
    child_task_id: str | None = None,
    interrupt_id: str | None = None,
    recovery_id: str | None = None,
    message_id: str | None = None,
) -> None:
    """Build an ``AgentEvent`` from the current state and append it to ``event_trace``.

    The event is stamped with the following:
    - the conversation id;
    - the current agent and sub-commander;
    - the message-thread id returned by ``_ensure_message_thread``.

    When *message_id* is falsy, the state's ``last_message_id`` is used instead.
    The trace list is replaced rather than mutated in place, and each event is
    stored as a JSON-mode dump.
    """
    thread_id = _ensure_message_thread(state)

    def _non_empty(value: Any) -> str | None:
        # Stringify, mapping None/empty values to None.
        text = str(value or "")
        return text or None

    event = AgentEvent(
        event_id=f"evt-{uuid4()}",
        event_type=cast(Any, event_type),
        conversation_id=_non_empty(state.get("conversation_id")),
        agent_id=_role_value(state.get("current_agent")),
        sub_commander_id=state.get("current_sub_commander"),
        task_id=task_id,
        parent_task_id=parent_task_id,
        child_task_id=child_task_id,
        thread_id=thread_id,
        message_id=message_id or _non_empty(state.get("last_message_id")),
        interrupt_id=interrupt_id,
        recovery_id=recovery_id,
        payload=payload or {},
        severity=cast(Any, severity),
    )
    state["event_trace"] = [*(state.get("event_trace") or []), event.model_dump(mode="json")]
|
||
|
||
|
||
def _capability_manifest_for_tool(tool_name: str):
    """Look up the built-in capability manifest registered for *tool_name*.

    Returns None when the tool has no capability id in the built-in registry
    indexes, or when that id has no manifest entry.
    """
    indexes = load_builtin_registry_indexes()
    capability_id = indexes.capability_id_by_tool_name.get(tool_name)
    return None if capability_id is None else indexes.capability_by_id.get(capability_id)
|
||
|
||
|
||
def _build_verifier_hints(state: AgentState, tool_name: str, result: Any) -> dict[str, Any]:
    """Summarise a tool call's capability metadata plus a result preview for the verifier.

    Capability attributes are read defensively with ``getattr``. For a tool
    without a manifest (None), the enum fields are None and the flags are False.
    *state* is accepted but not read here. The preview is the first 200
    characters of the stringified result.
    """
    manifest = _capability_manifest_for_tool(tool_name)

    def _enum_value(attribute: str) -> Any:
        return getattr(getattr(manifest, attribute, None), "value", None)

    def _flag(attribute: str) -> bool:
        return bool(getattr(manifest, attribute, False))

    return {
        "tool_name": tool_name,
        "permission_class": _enum_value("permission_class"),
        "side_effect_scope": _enum_value("side_effect_scope"),
        "requires_confirmation": _flag("requires_confirmation"),
        "supports_retry": _flag("supports_retry"),
        "safe_for_parallel_use": _flag("safe_for_parallel_use"),
        "result_preview": _stringify_message_content(result)[:200],
    }
|
||
|
||
|
||
def _update_task_result_summary(state: AgentState, tool_summaries: list[dict[str, Any]]) -> None:
|
||
if not tool_summaries:
|
||
return
|
||
|
||
previous_summary = state.get("task_result_summary") or {}
|
||
previous_tools = previous_summary.get("tools") or []
|
||
merged_tools = [*previous_tools, *tool_summaries]
|
||
summary = {
|
||
"tool_count": len(merged_tools),
|
||
"tools": merged_tools,
|
||
"created_count": sum(int(item.get("created_count") or 0) for item in merged_tools),
|
||
"created_entity_types": [
|
||
entity_type
|
||
for item in merged_tools
|
||
for entity_type in item.get("created_entity_types") or []
|
||
if entity_type
|
||
],
|
||
"stop_reason": state.get("stop_reason"),
|
||
}
|
||
state["task_result_summary"] = summary
|
||
state["action_results"] = [*(state.get("action_results") or []), summary]
|
||
|
||
|
||
def _record_sub_commander(state: AgentState, role: AgentRole, sub_commander: str, user_query: str) -> None:
    """Mark *role* / *sub_commander* as active and record the hand-off in the traces.

    Steps, in order:
    - allocate a new message id/index;
    - set the current agent and sub-commander;
    - register the role and sub-commander as active;
    - append to ``agent_trace``, ``sub_commander_trace`` and ``retrieval_trace``
      (the same entry dict goes into the last two).

    It then emits a ``task_request`` message trace, from the parent agent
    (default MASTER) to the state's ``agent_id``, and an ``agent.message.received``
    event. Both are tagged with the current collaboration task id, if any.
    """
    thread_id = _ensure_message_thread(state)
    message_id, message_index = _bump_message_sequence(state)
    agent_id = str(state.get("agent_id") or AgentRole.MASTER.value)

    def _collaboration_task_id() -> Any:
        # Re-read on each use so each trace sees the state as it is at that moment.
        return ((state.get("turn_context") or {}).get("collaboration_task") or {}).get("task_id")

    state["current_agent"] = role.value
    state["current_sub_commander"] = sub_commander

    active_agents = _normalize_active_agents(state.get("active_agents"))
    if role not in active_agents:
        active_agents = [*active_agents, role]
    state["active_agents"] = active_agents

    state["active_sub_commanders"] = [*(state.get("active_sub_commanders") or []), sub_commander]
    state["agent_trace"] = [*(state.get("agent_trace") or []), role.value]

    entry = {
        "agent": _role_value(role),
        "agent_id": agent_id,
        "sub_commander": sub_commander,
        "query": user_query,
        "thread_id": thread_id,
        "message_id": message_id,
        "message_index": message_index,
    }
    for trace_key in ("sub_commander_trace", "retrieval_trace"):
        state[trace_key] = [*(state.get(trace_key) or []), entry]

    _append_message_trace(
        state,
        from_agent_id=str(state.get("parent_agent_id") or AgentRole.MASTER.value),
        to_agent_id=agent_id,
        message_type="task_request",
        content_summary=user_query,
        task_id=_collaboration_task_id(),
        payload={"sub_commander": sub_commander},
        message_id=message_id,
        message_index=message_index,
    )
    _append_event_trace(
        state,
        "agent.message.received",
        payload={"sub_commander": sub_commander, "summary": user_query[:200]},
        task_id=_collaboration_task_id(),
        message_id=message_id,
    )
|
||
|
||
|
||
def _stop_sub_commander_due_to_budget(state: AgentState, reason: str) -> None:
|
||
state["stop_reason"] = reason
|
||
state["final_response"] = "这次需要处理的步骤有点多,我先停在这里。您可以把目标再明确一点,或让我先只完成其中一步。"
|
||
|
||
|
||
def _guard_sub_commander_budget(state: AgentState, counter_key: str, max_key: str, reason: str) -> bool:
    """Return True while ``state[counter_key]`` is still below ``state[max_key]``.

    A non-positive maximum disables the guard. When the budget is exhausted, the
    sub-commander is stopped with *reason* and False is returned.
    """
    limit = _get_state_int(state, max_key)
    used = _get_state_int(state, counter_key)
    exhausted = limit > 0 and used >= limit
    if exhausted:
        _stop_sub_commander_due_to_budget(state, reason)
    return not exhausted
|
||
|
||
|
||
def _classify_created_entity(tool_name: str) -> dict[str, str] | None:
|
||
mapping = {
|
||
"create_reminder": "reminder",
|
||
"create_goal": "goal",
|
||
"create_todo": "todo",
|
||
"create_schedule_task": "task",
|
||
"create_task": "task",
|
||
"create_forum_post": "forum_post",
|
||
}
|
||
entity_type = mapping.get(tool_name)
|
||
if not entity_type:
|
||
return None
|
||
return {"type": entity_type, "tool": tool_name}
|
||
|
||
|
||
def _build_reminder_time_clarification(args: dict[str, Any]) -> dict[str, Any] | None:
|
||
if not args.get("date") or args.get("reminder_at"):
|
||
return None
|
||
title = str(args.get("title") or args.get("content") or "这件事").strip() or "这件事"
|
||
return {
|
||
"question": f"要把“{title}”提醒在几点?如果您不想特地指定,我也可以默认按当天早上 9 点给您设置。",
|
||
"missing_fields": ["reminder_at"],
|
||
"partial_args": args,
|
||
}
|
||
|
||
|
||
def _has_active_clarification_context(state: AgentState) -> bool:
    """Return True when a clarification question is pending for a known agent role.

    Requires a non-empty ``clarification_context`` that has a question and whose
    ``owning_agent`` is a string among the known role values.
    """
    context = state.get("clarification_context") or {}
    if not (context and context.get("question")):
        return False
    owner = context.get("owning_agent")
    return isinstance(owner, str) and owner in _role_values()
|
||
|
||
|
||
def _clear_clarification_context(state: AgentState) -> None:
|
||
state["clarification_context"] = None
|
||
|
||
|
||
def _write_clarification_context(
|
||
state: AgentState,
|
||
*,
|
||
role: AgentRole,
|
||
sub_commander: str,
|
||
tool_name: str,
|
||
question: str,
|
||
partial_args: dict[str, Any] | None = None,
|
||
missing_fields: list[str] | None = None,
|
||
) -> None:
|
||
state["clarification_context"] = {
|
||
"owning_agent": role.value,
|
||
"owning_sub_commander": sub_commander,
|
||
"target_action": tool_name,
|
||
"question": question,
|
||
"partial_args": dict(partial_args or {}),
|
||
"missing_fields": list(missing_fields or []),
|
||
"status": "pending",
|
||
}
|
||
pending_action = state.get("pending_action") or {}
|
||
if pending_action.get("status") == "pending":
|
||
state["continuity_state"] = {"status": "fresh", "mode": "resume_after_clarification"}
|
||
|
||
|
||
def _route_from_clarification_context(state: AgentState, user_query: str) -> AgentRole | None:
    """Route a follow-up back to the agent that asked a pending clarification question.

    This only applies when a clarification is active and the user's message does
    not itself route to a specific agent, i.e. ``_route_agent_from_user_query``
    yields MASTER. Returns None when either condition fails, or when the stored
    owner is not a valid ``AgentRole``.
    """
    if not _has_active_clarification_context(state) or _route_agent_from_user_query(user_query) != AgentRole.MASTER:
        return None

    owner = (state.get("clarification_context") or {}).get("owning_agent") or ""
    try:
        return AgentRole(str(owner))
    except ValueError:
        return None
|
||
|
||
|
||
def _build_clarification_summary(state: AgentState) -> str | None:
    """Render the active clarification context as a system-prompt snippet.

    Returns None when no clarification is active. Otherwise the snippet lists the
    following, in order:
    - the owning agent and sub-commander, the target action and the question;
    - the missing fields and partial arguments, each only when present;
    - an instruction to merge the user's answer and continue the same action.
    """
    if not _has_active_clarification_context(state):
        return None

    context = state.get("clarification_context") or {}
    lines = ["clarification_context:"]
    lines.extend(
        f"- {key}: {context.get(key)}"
        for key in ("owning_agent", "owning_sub_commander", "target_action", "question")
    )

    missing = context.get("missing_fields") or []
    partial = context.get("partial_args") or {}
    if missing:
        lines.append("- missing_fields: " + ", ".join(map(str, missing)))
    if partial:
        lines.append("- partial_args: " + json.dumps(partial, ensure_ascii=False))
    lines.append("- instruction: merge the user's latest answer into the missing fields and continue the same action")
    return "\n".join(lines)
|
||
|
||
|
||
def _build_resumed_clarification_query(state: AgentState, user_query: str) -> str:
|
||
clarification_context = state.get("clarification_context") or {}
|
||
partial_args = clarification_context.get("partial_args") or {}
|
||
target_action = clarification_context.get("target_action") or ""
|
||
if not partial_args or target_action != "create_reminder":
|
||
return user_query
|
||
|
||
parts: list[str] = []
|
||
title = partial_args.get("title") or partial_args.get("content")
|
||
if title:
|
||
parts.append(f"title={title}")
|
||
date = partial_args.get("date")
|
||
if date:
|
||
parts.append(f"date={date}")
|
||
parts.append(f"reminder_at={user_query}")
|
||
return f"继续完成提醒创建,请合并参数:{';'.join(parts)}"
|
||
|
||
|
||
def _prepare_tool_calls_for_execution(
    tool_calls: list[dict[str, Any]],
    state: AgentState,
) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
    """Canonicalise and time-normalise raw tool calls before they are executed.

    Each call goes through ``_canonicalize_tool_call`` and then
    ``normalize_tool_time_arguments``, using the state's datetime context.

    If a ``create_reminder`` call still lacks a reminder time, preparation stops
    and ``([], clarification)`` is returned. *clarification* holds the tool name,
    the question, the partial args and the missing fields. Otherwise the result
    is ``(prepared_calls, None)``, where each prepared call keeps its ``id`` and
    ``reason``.
    """
    datetime_context = state.get("current_datetime_context")
    prepared: list[dict[str, Any]] = []

    for raw_call in tool_calls:
        name, canonical_args = _canonicalize_tool_call(raw_call["name"], dict(raw_call.get("args") or {}))
        args = normalize_tool_time_arguments(name, canonical_args, datetime_context)

        pending = _build_reminder_time_clarification(args) if name == "create_reminder" else None
        if pending:
            return [], {
                "tool_name": name,
                "question": pending["question"],
                "partial_args": pending.get("partial_args") or args,
                "missing_fields": pending.get("missing_fields") or [],
            }

        prepared.append(
            {
                "id": raw_call.get("id"),
                "name": name,
                "args": args,
                "reason": raw_call.get("reason"),
            }
        )

    return prepared, None
|
||
|
||
|
||
def _canonicalize_tool_call(tool_name: str, args: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||
normalized_name = tool_name
|
||
normalized_args = dict(args)
|
||
|
||
if normalized_name == "create_reminder":
|
||
if normalized_args.get("description") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["content"] = normalized_args["description"]
|
||
if normalized_args.get("reminder_content") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["content"] = normalized_args["reminder_content"]
|
||
if normalized_args.get("reminder_time") and not (
|
||
normalized_args.get("reminder_at")
|
||
or normalized_args.get("datetime")
|
||
or normalized_args.get("at")
|
||
or normalized_args.get("remind_at")
|
||
or normalized_args.get("time")
|
||
):
|
||
normalized_args["time"] = normalized_args["reminder_time"]
|
||
|
||
if normalized_name in {"create_schedule_task", "create_task"}:
|
||
if normalized_args.get("task") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["title"] = normalized_args["task"]
|
||
if normalized_args.get("due_datetime") and not (normalized_args.get("due_date") or normalized_args.get("date")):
|
||
normalized_args["due_date"] = normalized_args["due_datetime"]
|
||
if normalized_args.get("due_time") and not normalized_args.get("due_date"):
|
||
normalized_args["due_date"] = normalized_args["due_time"]
|
||
|
||
if normalized_name == "create_todo":
|
||
if normalized_args.get("task") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["title"] = normalized_args["task"]
|
||
if normalized_args.get("date") and not normalized_args.get("todo_date"):
|
||
normalized_args["todo_date"] = normalized_args["date"]
|
||
if normalized_args.get("due_date") and not normalized_args.get("todo_date"):
|
||
normalized_args["todo_date"] = normalized_args["due_date"]
|
||
if normalized_args.get("due_datetime") and not normalized_args.get("todo_date"):
|
||
normalized_args["todo_date"] = normalized_args["due_datetime"]
|
||
if any(normalized_args.get(key) for key in ("due_datetime", "start_time", "end_time", "due_time")):
|
||
normalized_name = "create_schedule_task"
|
||
if normalized_args.get("due_time") and not normalized_args.get("due_date"):
|
||
normalized_args["due_date"] = normalized_args["due_time"]
|
||
elif normalized_args.get("todo_date") and not normalized_args.get("due_date"):
|
||
normalized_args["due_date"] = normalized_args["todo_date"]
|
||
|
||
if normalized_name == "create_goal":
|
||
if normalized_args.get("task") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["title"] = normalized_args["task"]
|
||
if normalized_args.get("date") and not normalized_args.get("goal_date"):
|
||
normalized_args["goal_date"] = normalized_args["date"]
|
||
if normalized_args.get("target_date") and not normalized_args.get("goal_date"):
|
||
normalized_args["goal_date"] = normalized_args["target_date"]
|
||
|
||
return normalized_name, normalized_args
|
||
|
||
|
||
async def _invoke_llm(llm, messages: list[BaseMessage], tools: list[Any] | None = None):
    """Invoke *llm* asynchronously on *messages*, binding *tools* first when given.

    System messages are merged by ``_coalesce_system_messages`` before the call.
    An empty or None *tools* list leaves the model unbound.
    """
    prepared_messages = _coalesce_system_messages(messages)
    runnable = llm.bind_tools(tools) if tools else llm
    return await runnable.ainvoke(prepared_messages)
|
||
|
||
|
||
async def _execute_tool_calls(
    tool_calls: list[dict[str, Any]],
    toolset: list[Any],
    state: AgentState,
) -> tuple[list[dict[str, Any]], str, list[dict[str, str]], list[ToolMessage]]:
    """Execute prepared tool calls sequentially and record their outcomes on *state*.

    For every call this function:
    - emits ``agent.tool.start`` and ``agent.tool.result`` events, plus
      ``agent.error`` on failure;
    - appends a ``tool_outcomes`` entry carrying the verifier hints;
    - builds a ``ToolMessage`` keyed by the call id, or by the tool name when no
      id is present;
    - records the created entity when a creation tool succeeds.

    After the loop it refreshes ``verifier_hints`` and the aggregated task result
    summary.

    A failing tool does not abort the batch: the exception is logged and a
    ``工具执行失败: ...`` result string is used in its place. A tool name that is
    not in *toolset* is handled the same way, so every call still gets a
    matching ``ToolMessage``.

    Returns ``(normalized_calls, joined_result_text, created_entities, tool_messages)``.
    """
    tool_map = {tool.name: tool for tool in toolset}
    normalized_calls: list[dict[str, Any]] = []
    result_lines: list[str] = []
    created_entities: list[dict[str, str]] = []
    tool_messages: list[ToolMessage] = []
    verifier_hints_by_tool: list[dict[str, Any]] = []
    tool_summaries: list[dict[str, Any]] = []

    for call in tool_calls:
        tool_name = call["name"]
        normalized_args = dict(call.get("args") or {})
        # The tool-call id doubles as the event task id (None when absent/empty).
        event_task_id = str(call.get("id") or "") or None
        tool = tool_map.get(tool_name)

        _append_event_trace(
            state,
            "agent.tool.start",
            payload={"tool_name": tool_name, "args": normalized_args},
            task_id=event_task_id,
        )

        try:
            if tool is None:
                # Names here are not guaranteed to be in the toolset. Native tool
                # calls are not validated against it, and _canonicalize_tool_call
                # may reroute create_todo -> create_schedule_task, which may be
                # missing from this toolset. Report it like any other tool failure
                # instead of aborting the whole batch half-recorded.
                raise ValueError(f"Tool not found: {tool_name}")
            if hasattr(tool, "ainvoke"):
                result = await tool.ainvoke(normalized_args)
            else:
                # Sync-only tools run in a worker thread so the event loop is not blocked.
                result = await asyncio.to_thread(tool.invoke, normalized_args)
        except Exception as exc:
            logger.exception("Tool execution failed: %s args=%s", tool_name, normalized_args)
            result = f"工具执行失败: {exc}"
            _append_event_trace(
                state,
                "agent.error",
                payload={"tool_name": tool_name, "args": normalized_args, "error": str(exc)},
                severity="error",
                task_id=event_task_id,
            )

        result_text = _stringify_message_content(result)
        result_preview = result_text[:200]
        failed = _tool_result_indicates_failure(result)

        normalized_calls.append(
            {
                "id": call.get("id"),
                "name": tool_name,
                "args": normalized_args,
                "reason": call.get("reason"),
            }
        )
        result_lines.append(f"[{tool_name}] {result}")

        verifier_hints = _build_verifier_hints(state, tool_name, result)
        verifier_hints_by_tool.append(verifier_hints)
        state["tool_outcomes"] = [
            *(state.get("tool_outcomes") or []),
            {
                "tool_name": tool_name,
                "args": normalized_args,
                "result_preview": result_preview,
                "verifier_hints": verifier_hints,
            },
        ]
        _append_event_trace(
            state,
            "agent.tool.result",
            payload={
                "tool_name": tool_name,
                "args": normalized_args,
                "result_preview": result_preview,
                "verification": verifier_hints,
            },
            severity="error" if failed else "info",
            task_id=event_task_id,
        )
        tool_messages.append(
            ToolMessage(
                content=result_text,
                tool_call_id=str(call.get("id") or tool_name),
                name=tool_name,
            )
        )

        # Only successful creation tools count as having created an entity.
        entity = None if failed else _classify_created_entity(tool_name)
        call_created_entities = [entity] if entity else []
        created_entities.extend(call_created_entities)
        tool_summaries.append(
            {
                "tool_name": tool_name,
                "result_preview": result_preview,
                "created_entity_types": [item.get("type") for item in call_created_entities if item.get("type")],
                "created_count": len(call_created_entities),
            }
        )

    state["verifier_hints"] = {"tools": verifier_hints_by_tool}
    _update_task_result_summary(state, tool_summaries)
    return normalized_calls, "\n".join(result_lines), created_entities, tool_messages
|
||
|
||
|
||
async def _run_sub_commander(
|
||
state: AgentState,
|
||
role: AgentRole,
|
||
manager_prompt: str,
|
||
user_query: str,
|
||
*,
|
||
use_tools: bool,
|
||
summary_target: str | None = None,
|
||
):
|
||
state["clarification_needed"] = False
|
||
state["clarification_question"] = None
|
||
state["stop_reason"] = None
|
||
|
||
llm = _get_llm_for_state(state)
|
||
capabilities = _resolve_capabilities(state, llm)
|
||
if _has_active_clarification_context(state) and role.value == str((state.get("clarification_context") or {}).get("owning_agent") or ""):
|
||
user_query = _build_resumed_clarification_query(state, user_query)
|
||
sub_commander = _choose_sub_commander(role, user_query)
|
||
_record_sub_commander(state, role, sub_commander, user_query)
|
||
|
||
toolset = SUB_COMMANDER_TOOLSETS.get(sub_commander, []) if use_tools else []
|
||
if (
|
||
role == AgentRole.EXECUTOR
|
||
and _is_short_confirmation(user_query)
|
||
and _previous_turn_completed_reminder_creation(state)
|
||
):
|
||
state["tool_calls"] = []
|
||
state["last_tool_result"] = None
|
||
state["final_response"] = "上一条提醒已经创建好了。若您现在要新建别的内容,请直接告诉我要创建什么。"
|
||
history_messages = list(state.get("messages", []))
|
||
history_messages.append(AIMessage(content=state["final_response"]))
|
||
state["messages"] = history_messages
|
||
state["should_respond"] = True
|
||
return state
|
||
base_messages = _build_system_messages(state, manager_prompt, role, sub_commander)
|
||
conversation_history = _conversation_history_messages(state)
|
||
if conversation_history and getattr(conversation_history[-1], "type", "") in {"human", "user"}:
|
||
conversation_history = conversation_history[:-1]
|
||
user_message = HumanMessage(content=f"用户请求: {user_query}")
|
||
working_messages = [*base_messages, *conversation_history, user_message]
|
||
|
||
state["tool_calls"] = []
|
||
state["last_tool_result"] = None
|
||
state["tool_strategy_used"] = None
|
||
state["fallback_parse_error"] = None
|
||
start_tool_index = len(state.get("tool_outcomes") or [])
|
||
|
||
if not _guard_sub_commander_budget(state, "tool_round_count", "max_tool_rounds", "max_tool_rounds_exceeded"):
|
||
pass
|
||
elif not _guard_sub_commander_budget(state, "retry_count", "max_retries", "max_retries_exceeded"):
|
||
pass
|
||
elif not toolset:
|
||
if _guard_sub_commander_budget(state, "iteration_count", "max_iterations", "max_iterations_exceeded"):
|
||
state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
|
||
response = await _invoke_llm(llm, working_messages)
|
||
state["final_response"] = _stringify_message_content(response.content)
|
||
elif capabilities.supports_native_tools:
|
||
state["tool_strategy_used"] = "native"
|
||
bound_llm = llm.bind_tools(toolset)
|
||
while state.get("final_response") is None and not state.get("clarification_needed"):
|
||
if not _guard_sub_commander_budget(state, "iteration_count", "max_iterations", "max_iterations_exceeded"):
|
||
break
|
||
state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
|
||
response = await _invoke_llm(bound_llm, working_messages)
|
||
tool_calls = getattr(response, "tool_calls", None) or []
|
||
if tool_calls:
|
||
if not _guard_sub_commander_budget(state, "tool_round_count", "max_tool_rounds", "max_tool_rounds_exceeded"):
|
||
break
|
||
prepared_calls, clarification = _prepare_tool_calls_for_execution(tool_calls, state)
|
||
if clarification:
|
||
state["clarification_needed"] = True
|
||
state["clarification_question"] = clarification["question"]
|
||
_write_clarification_context(
|
||
state,
|
||
role=role,
|
||
sub_commander=sub_commander,
|
||
tool_name=clarification["tool_name"],
|
||
question=clarification["question"],
|
||
partial_args=clarification.get("partial_args"),
|
||
missing_fields=clarification.get("missing_fields"),
|
||
)
|
||
state["stop_reason"] = "clarification_needed"
|
||
state["final_response"] = clarification["question"]
|
||
break
|
||
state["tool_round_count"] = int(state.get("tool_round_count") or 0) + 1
|
||
assistant_tool_message = AIMessage(
|
||
content=_stringify_message_content(getattr(response, "content", "")),
|
||
tool_calls=tool_calls,
|
||
)
|
||
normalized_calls, tool_result, created_entities, tool_messages = await _execute_tool_calls(
|
||
prepared_calls,
|
||
toolset,
|
||
state,
|
||
)
|
||
state["tool_calls"] = normalized_calls
|
||
state["last_tool_result"] = tool_result
|
||
state["created_entities"] = [*(state.get("created_entities") or []), *created_entities]
|
||
if created_entities:
|
||
_clear_clarification_context(state)
|
||
if role == AgentRole.SCHEDULE_PLANNER and _should_clear_schedule_creation_continuity(state, created_entities):
|
||
_clear_structured_continuity(state)
|
||
working_messages = [*working_messages, assistant_tool_message, *tool_messages]
|
||
if sub_commander == "librarian_retrieval" and _is_missing_knowledge_result(tool_result):
|
||
working_messages.append(SystemMessage(content="如果检索工具没有找到证据,可以直接基于你的常识给出清晰回答,不要机械地说不知道。"))
|
||
continue
|
||
state["final_response"] = _stringify_message_content(response.content)
|
||
else:
|
||
state["tool_strategy_used"] = "json_fallback"
|
||
allowed_tools = [tool.name for tool in toolset]
|
||
while state.get("final_response") is None and not state.get("clarification_needed"):
|
||
if not _guard_sub_commander_budget(state, "iteration_count", "max_iterations", "max_iterations_exceeded"):
|
||
break
|
||
state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
|
||
parsed = None
|
||
retry_instruction: BaseMessage | None = None
|
||
while parsed is None:
|
||
response = await _invoke_llm(
|
||
llm,
|
||
[
|
||
*working_messages,
|
||
SystemMessage(content=JSON_ACTION_FALLBACK_PROMPT),
|
||
SystemMessage(content=f"本次可用工具列表: {', '.join(allowed_tools)}"),
|
||
*([retry_instruction] if retry_instruction else []),
|
||
],
|
||
)
|
||
response_text = _stringify_message_content(response.content)
|
||
parsed = _parse_json_action(response_text, allowed_tools)
|
||
if parsed is None and response_text.strip() and state.get("tool_round_count"):
|
||
state["fallback_parse_error"] = None
|
||
state["final_response"] = response_text.strip()
|
||
break
|
||
if parsed is not None:
|
||
state["fallback_parse_error"] = None
|
||
break
|
||
if not _guard_sub_commander_budget(state, "iteration_count", "max_iterations", "max_iterations_exceeded"):
|
||
parsed = None
|
||
break
|
||
if int(state.get("retry_count") or 0) >= int(state.get("max_retries") or 0):
|
||
state["fallback_parse_error"] = "invalid_json_action"
|
||
state["final_response"] = "这次内部动作解析没整理好,不过您的意思我接住了。您再说一遍要我执行的内容,我只回结果,不展示内部调用细节。"
|
||
break
|
||
state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
|
||
state["retry_count"] = int(state.get("retry_count") or 0) + 1
|
||
retry_instruction = SystemMessage(content="上一次输出不是有效 JSON。请严格只返回合法 JSON,不要加解释。")
|
||
if state.get("final_response") is not None:
|
||
break
|
||
if parsed is None:
|
||
break
|
||
if parsed["mode"] == "final":
|
||
state["final_response"] = parsed["final_response"]
|
||
break
|
||
if parsed["mode"] == "clarification":
|
||
state["clarification_needed"] = True
|
||
state["clarification_question"] = parsed["clarification_question"]
|
||
_write_clarification_context(
|
||
state,
|
||
role=role,
|
||
sub_commander=sub_commander,
|
||
tool_name="clarification",
|
||
question=parsed["clarification_question"],
|
||
)
|
||
state["stop_reason"] = "clarification_needed"
|
||
state["final_response"] = parsed["clarification_question"]
|
||
break
|
||
if not _guard_sub_commander_budget(state, "tool_round_count", "max_tool_rounds", "max_tool_rounds_exceeded"):
|
||
break
|
||
prepared_calls, clarification = _prepare_tool_calls_for_execution(parsed["tool_calls"], state)
|
||
if clarification:
|
||
state["clarification_needed"] = True
|
||
state["clarification_question"] = clarification["question"]
|
||
_write_clarification_context(
|
||
state,
|
||
role=role,
|
||
sub_commander=sub_commander,
|
||
tool_name=clarification["tool_name"],
|
||
question=clarification["question"],
|
||
partial_args=clarification.get("partial_args"),
|
||
missing_fields=clarification.get("missing_fields"),
|
||
)
|
||
state["stop_reason"] = "clarification_needed"
|
||
state["final_response"] = clarification["question"]
|
||
break
|
||
state["tool_round_count"] = int(state.get("tool_round_count") or 0) + 1
|
||
normalized_calls, tool_result, created_entities, tool_messages = await _execute_tool_calls(
|
||
prepared_calls,
|
||
toolset,
|
||
state,
|
||
)
|
||
state["tool_calls"] = normalized_calls
|
||
state["last_tool_result"] = tool_result
|
||
state["created_entities"] = [*(state.get("created_entities") or []), *created_entities]
|
||
if role == AgentRole.SCHEDULE_PLANNER and _should_clear_schedule_creation_continuity(state, created_entities):
|
||
_clear_structured_continuity(state)
|
||
working_messages = [*working_messages, *tool_messages]
|
||
if sub_commander == "librarian_retrieval" and _is_missing_knowledge_result(tool_result):
|
||
working_messages.append(SystemMessage(content="如果检索工具没有找到证据,可以直接基于你的常识给出清晰回答,不要机械地说不知道。"))
|
||
|
||
if summary_target:
|
||
state[_summary_state_key(summary_target)] = state.get("final_response")
|
||
|
||
task_result_summary = state.get("task_result_summary")
|
||
tool_outcomes = list(state.get("tool_outcomes") or [])[start_tool_index:]
|
||
has_tool_failure = any(
|
||
_tool_result_indicates_failure(outcome.get("result_preview"))
|
||
for outcome in tool_outcomes
|
||
)
|
||
verifier_input = {
|
||
"summary": state.get("final_response") or (task_result_summary or {}).get("tools"),
|
||
"evidence": tool_outcomes,
|
||
"success": bool(tool_outcomes or state.get("final_response")) and not has_tool_failure,
|
||
}
|
||
_append_event_trace(
|
||
state,
|
||
"agent.verify.started",
|
||
payload={
|
||
"summary_present": bool(verifier_input["summary"]),
|
||
"evidence_count": len(verifier_input["evidence"]),
|
||
},
|
||
)
|
||
verdict = verify_task_result(
|
||
summary=state.get("final_response"),
|
||
evidence=tool_outcomes,
|
||
result=verifier_input,
|
||
)
|
||
updated_state = apply_verification_verdict(state, verdict)
|
||
state.update(updated_state)
|
||
_append_event_trace(
|
||
state,
|
||
"agent.verify.completed",
|
||
payload={
|
||
"status": verdict.status,
|
||
"summary": verdict.summary,
|
||
"evidence_count": len(verdict.evidence),
|
||
},
|
||
severity="error" if verdict.status == "failed" else "info",
|
||
)
|
||
|
||
final_response_text = state.get("final_response")
|
||
if not state.get("clarification_needed") and final_response_text:
|
||
_clear_clarification_context(state)
|
||
if (
|
||
role == AgentRole.SCHEDULE_PLANNER
|
||
and isinstance(final_response_text, str)
|
||
and _is_schedule_creation_confirmation_response(final_response_text)
|
||
):
|
||
_write_schedule_creation_continuity(state, user_query)
|
||
|
||
history_messages = list(state.get("messages", []))
|
||
final_response = state.get("final_response")
|
||
if isinstance(final_response, str):
|
||
history_messages.append(AIMessage(content=final_response))
|
||
|
||
state["messages"] = history_messages
|
||
state["should_respond"] = True
|
||
return state
|
||
|
||
|
||
def _derive_task_status(state: AgentState) -> Literal["completed", "failed", "blocked"]:
|
||
if state.get("clarification_needed") or state.get("stop_reason") == "clarification_needed":
|
||
return "blocked"
|
||
if state.get("verification_status") == "failed":
|
||
return "failed"
|
||
if state.get("stop_reason") and state.get("stop_reason") != "clarification_needed":
|
||
return "blocked"
|
||
if state.get("final_response"):
|
||
return "completed"
|
||
return "failed"
|
||
|
||
|
||
def _build_task_evidence(state: AgentState, start_index: int) -> list[dict[str, Any]]:
|
||
tool_outcomes = [dict(item) for item in list(state.get("tool_outcomes") or [])[start_index:]]
|
||
if tool_outcomes:
|
||
evidence = tool_outcomes
|
||
else:
|
||
evidence = []
|
||
|
||
if state.get("verification_status") or state.get("verification_summary"):
|
||
evidence.append(
|
||
{
|
||
"type": "verification",
|
||
"status": state.get("verification_status"),
|
||
"summary": state.get("verification_summary"),
|
||
}
|
||
)
|
||
if not evidence and state.get("final_response"):
|
||
evidence.append(
|
||
{
|
||
"type": "response_summary",
|
||
"content": _stringify_message_content(state.get("final_response"))[:400],
|
||
}
|
||
)
|
||
return evidence
|
||
|
||
|
||
def _collect_task_result(task: AgentTask, state: AgentState, start_tool_index: int) -> TaskResult:
    """Snapshot the state left by one collaboration sub-task into a ``TaskResult``.

    ``next_action`` is set as follows:
    - the pending clarification question, when there is one;
    - otherwise, for a task that did not complete and has a recorded
      ``stop_reason``, the string ``"resolve_<stop_reason>"``.

    Tool evidence only covers outcomes recorded at or after ``start_tool_index``.
    Falsy thread/message identifiers are normalised to ``None``.
    """
    status = _derive_task_status(state)
    stop_reason = state.get("stop_reason")

    next_action = state.get("clarification_question")
    if not next_action and status != "completed" and stop_reason:
        next_action = f"resolve_{stop_reason}"

    summary_source = state.get("final_response") or state.get("verification_summary")
    thread_id = str(state.get("thread_id") or "") or None
    message_id = str(state.get("last_message_id") or "") or None
    message_index = int(state.get("message_sequence") or 0) or None

    output_data = {
        "role": task.role,
        "sub_commander": state.get("current_sub_commander"),
        "verification_status": state.get("verification_status"),
    }
    return TaskResult(
        task_id=task.task_id,
        status=status,
        summary=_stringify_message_content(summary_source),
        evidence=_build_task_evidence(state, start_tool_index),
        owner_agent_id=task.owner_agent_id,
        parent_task_id=task.parent_task_id,
        child_task_ids=list(task.child_task_ids),
        thread_id=thread_id,
        message_id=message_id,
        message_index=message_index,
        interrupt_records=list(state.get("interrupted_tasks") or []),
        recovery_records=list(state.get("recovery_trace") or []),
        budget_snapshot=state.get("budget_state"),
        next_action=next_action,
        output_data=output_data,
    )
|
||
|
||
|
||
def _apply_task_result_to_state(state: AgentState, task: AgentTask, task_result: TaskResult) -> None:
    """Record one sub-task result in the collaboration bookkeeping on ``state``.

    Effects:
    - appends the JSON-serialised result to ``task_results``;
    - rewrites ``active_tasks`` as plain dicts, updating the matching task's
      status, evidence, result summary and (when present) next action;
    - recomputes ``pending_tasks`` as every active task whose status is not
      ``"completed"``;
    - when the task was found in ``active_tasks``, appends a summary entry to
      ``completed_tasks``.
    """
    result = normalize_task_result(task_result, default_task_id=task.task_id)
    serialized = result.model_dump(mode="json")
    state["task_results"] = [*(state.get("task_results") or []), serialized]

    refreshed_tasks: list[dict[str, Any]] = []
    finished_entry: dict[str, Any] | None = None
    for entry in state.get("active_tasks") or []:
        if isinstance(entry, AgentTask):
            task_dict = entry.model_dump(mode="json")
        else:
            task_dict = dict(entry)

        if task_dict.get("task_id") == task.task_id:
            task_dict.update(
                status=result.status,
                evidence=list(result.evidence),
                result_summary=result.summary,
            )
            if result.next_action:
                task_dict["next_action"] = result.next_action
            finished_entry = {
                "task_id": task.task_id,
                "role": task.role,
                "owner_agent_id": task.owner_agent_id,
                "status": result.status,
                "summary": result.summary,
                "next_action": result.next_action,
            }
        refreshed_tasks.append(task_dict)

    state["active_tasks"] = refreshed_tasks
    state["pending_tasks"] = [item for item in refreshed_tasks if item.get("status") != "completed"]

    if finished_entry is not None:
        # NOTE(review): failed/blocked results are appended here too, despite the key name.
        state["completed_tasks"] = [*(state.get("completed_tasks") or []), finished_entry]
|
||
|
||
|
||
def _build_collaboration_final_response(task_results: list[TaskResult | dict[str, Any]]) -> str:
    """Render the user-facing summary text for a collaboration run.

    The text contains:
    - a header with the number of collected results;
    - one numbered ``[owner] (status) summary`` line per result;
    - exactly one closing line.

    The closing line is chosen in this order:
    1. ``汇总结论`` with the summary of the last completed result that has one;
    2. otherwise ``下一步`` with the ``next_action`` of the last result that has one;
    3. otherwise the fixed "final summary is empty" conclusion.

    The chain is exhaustive, so the response never ends without a closing line.
    """
    normalized_results = [normalize_task_result(item) for item in task_results]
    lines = [f"已按协作模式回收 {len(normalized_results)} 个子任务结果:"]

    final_completed_summary: str | None = None
    for index, result in enumerate(normalized_results, start=1):
        owner = result.owner_agent_id or "unknown"
        summary = result.summary or "未返回摘要。"
        lines.append(f"{index}. [{owner}] ({result.status}) {summary}")
        if result.status == "completed" and result.summary:
            final_completed_summary = result.summary

    # The most recent result that asks for a follow-up, if any.
    next_action = next(
        (item.next_action for item in reversed(normalized_results) if item.next_action),
        None,
    )

    if final_completed_summary:
        lines.append(f"汇总结论:{final_completed_summary}")
    elif next_action:
        lines.append(f"下一步:{next_action}")
    else:
        lines.append("汇总结论:协作任务已执行,但最终摘要为空。")
    return "\n".join(lines)
|
||
|
||
|
||
def _verify_collaboration_results(
    state: AgentState,
    tasks: list[AgentTask],
    task_results: list[TaskResult | dict[str, Any]] | None = None,
) -> None:
    """Issue the collaboration-level verification verdict and merge it into ``state``.

    Results are taken from ``task_results`` when given, else from
    ``state["task_results"]``. Each is normalised once and filtered to the task
    ids in ``tasks``.

    The verdict is ``failed``, checked in this order, when:
    1. a task has no result;
    2. any result is ``failed`` or ``blocked``;
    3. any result carries no evidence.

    Otherwise the verdict is ``passed``. The output of
    ``apply_verification_verdict`` is merged into ``state`` in place.
    """
    expected_task_ids = {task.task_id for task in tasks}
    raw_results = task_results if task_results is not None else (state.get("task_results") or [])
    # Normalise each item exactly once. Calling normalize_task_result separately for the
    # filter and for the kept value duplicated the work, and the two could disagree if
    # normalisation is not deterministic (e.g. a generated default id).
    normalized_results = [
        result
        for result in map(normalize_task_result, raw_results)
        if result.task_id in expected_task_ids
    ]

    result_by_task_id = {item.task_id: item for item in normalized_results}
    missing_task_ids = [task.task_id for task in tasks if task.task_id not in result_by_task_id]
    failed_or_blocked = [
        item.task_id
        for item in normalized_results
        if item.status in {"failed", "blocked"}
    ]
    missing_evidence = [
        item.task_id
        for item in normalized_results
        if not item.evidence
    ]

    verification_evidence = [
        {
            "task_id": item.task_id,
            "owner_agent_id": item.owner_agent_id,
            "status": item.status,
            "evidence_count": len(item.evidence or []),
        }
        for item in normalized_results
    ]

    if missing_task_ids:
        summary = f"协作结果不完整,缺少任务结果: {', '.join(missing_task_ids)}"
        verdict = verify_task_result(status="failed", summary=summary, evidence=verification_evidence)
    elif failed_or_blocked:
        summary = f"协作结果未闭环,存在失败或阻塞任务: {', '.join(failed_or_blocked)}"
        verdict = verify_task_result(status="failed", summary=summary, evidence=verification_evidence)
    elif missing_evidence:
        summary = f"协作结果证据不足,缺少 evidence 的任务: {', '.join(missing_evidence)}"
        verdict = verify_task_result(status="failed", summary=summary, evidence=verification_evidence)
    else:
        summary = f"协作模式已完成 {len(normalized_results)}/{len(tasks)} 个子任务,并为每个子任务回收了结果与 evidence。"
        verdict = verify_task_result(status="passed", summary=summary, evidence=verification_evidence)

    updated_state = apply_verification_verdict(state, verdict)
    state.update(updated_state)
|
||
|
||
|
||
async def _run_collaboration_flow(state: AgentState, user_query: str) -> AgentState:
    """Run a multi-role request as a sequence of delegated sub-tasks.

    Builds the collaboration task plan for ``user_query``. With fewer than two
    tasks it marks the turn as ``direct`` (reason ``collaboration_plan_fell_back``)
    and returns without running anything.

    Otherwise, for each task in order:
    - a child agent runs it via ``_run_sub_commander``;
    - its ``TaskResult`` is collected, traced and recorded on ``state``;
    - the loop stops at the first task whose status is not ``completed``.

    Afterwards the coordinator identity is restored, a combined response is
    built and verified, and the response is appended to the original message
    history.

    Returns:
        The same ``state`` object, mutated in place.
    """
    tasks = _build_collaboration_tasks(user_query)
    if len(tasks) < 2:
        # Not enough sub-tasks to collaborate; the caller continues with direct routing.
        state["execution_mode"] = "direct"
        state["routing_decision"] = {"mode": "direct", "reason": "collaboration_plan_fell_back"}
        return state

    # Every sub-task starts from this history (reset at the top of each iteration),
    # so tasks do not see each other's turns.
    base_history = list(state.get("messages", []))
    root_agent_id = str(state.get("root_agent_id") or state.get("agent_id") or AgentRole.MASTER.value)
    coordinator_agent_id = str(state.get("agent_id") or AgentRole.MASTER.value)
    state["execution_mode"] = "collaboration"
    state["routing_decision"] = {
        "mode": "collaboration",
        "reason": "multi_role_request",
        "task_count": len(tasks),
        "roles": [task.role for task in tasks],
    }
    budget_snapshot = _budget_snapshot_from_state(
        state,
        mode="collaboration",
        remaining_parallel_tasks=len(tasks),
        metadata={"task_count": len(tasks)},
    )
    _store_budget_snapshot(state, budget_snapshot)
    _append_event_trace(
        state,
        "agent.collaboration.budget.updated",
        payload=budget_snapshot,
    )
    state["active_tasks"] = [task.model_dump(mode="json") for task in tasks]
    # All tasks are grouped under the first declared parent task id, or "root" if none has one.
    parent_task_id = next((task.parent_task_id for task in tasks if task.parent_task_id), None) or "root"
    state["task_hierarchy"] = {parent_task_id: [task.task_id for task in tasks]}
    state["task_results"] = []
    state["next_step"] = None

    for task in tasks:
        # Act as the coordinator (depth 0) while spawning the child agent for this task.
        state["current_agent"] = AgentRole.MASTER.value
        state["agent_id"] = coordinator_agent_id
        state["parent_agent_id"] = None
        state["root_agent_id"] = root_agent_id
        state["collaboration_depth"] = 0
        assigned_role = _assign_agent_for_task_role(task.role)
        child_agent_id = _create_child_agent(state, role=assigned_role, task=task)
        if child_agent_id is None:
            # Spawning was refused. Record the interrupt and the recovery strategy, then
            # still run the task under a placeholder agent id instead of skipping it.
            interrupt = _record_interrupt(
                state,
                reason="spawn_blocked",
                task=task,
                payload={"target_role": assigned_role.value},
            )
            _record_recovery(
                state,
                interrupt=interrupt,
                strategy="fallback_to_direct_role_execution",
                task=task,
                payload={"target_role": assigned_role.value},
            )
            child_agent_id = f"blocked-{assigned_role.value}"
        # Switch into the child agent's identity (depth 1 under the coordinator).
        state["messages"] = list(base_history)
        state["current_agent"] = assigned_role.value
        state["agent_id"] = child_agent_id
        state["parent_agent_id"] = coordinator_agent_id
        state["root_agent_id"] = root_agent_id
        state["collaboration_depth"] = 1
        task.thread_id = str(state.get("thread_id") or "") or None
        task.message_id = str(state.get("last_message_id") or "") or None
        task.message_index = int(state.get("message_sequence") or 0) or None
        task.collaboration_budget = state.get("budget_state")
        state["turn_context"] = {"collaboration_task": task.model_dump(mode="json")}
        # Clear per-task outputs so the previous task's response/verdict cannot leak into this one.
        state["final_response"] = None
        state["verification_status"] = None
        state["verification_summary"] = None
        state["verification_evidence"] = []
        state["clarification_needed"] = False
        state["clarification_question"] = None
        state["stop_reason"] = None
        # Tool outcomes accumulate across tasks; remember where this task's outcomes begin.
        start_tool_index = len(state.get("tool_outcomes") or [])

        await _run_sub_commander(
            state,
            assigned_role,
            ROLE_SYSTEM_PROMPTS[assigned_role],
            task.goal or user_query,
            use_tools=True,
            summary_target=None,
        )

        task_result = _collect_task_result(task, state, start_tool_index)
        # Report the result from the child agent back to the coordinator.
        _append_message_trace(
            state,
            from_agent_id=child_agent_id,
            to_agent_id=coordinator_agent_id,
            message_type="task_update",
            content_summary=task_result.summary or task.title,
            task_id=task.task_id,
            reply_to_message_id=task.message_id,
            payload={"status": task_result.status, "owner_agent_id": task_result.owner_agent_id},
        )
        _append_event_trace(
            state,
            "agent.message.sent",
            payload={"status": task_result.status, "summary": (task_result.summary or "")[:200]},
            task_id=task.task_id,
            parent_task_id=task.parent_task_id,
            child_task_id=(task.child_task_ids or [None])[0],
            message_id=str(state.get("last_message_id") or "") or None,
        )
        _apply_task_result_to_state(state, task, task_result)

        # Stop at the first unfinished task. Remaining tasks are left unrun; the verifier
        # below then reports them as missing results.
        if task_result.status != "completed":
            break

    # Restore the coordinator identity before composing the combined answer.
    state["turn_context"] = None
    state["current_agent"] = AgentRole.MASTER.value
    state["agent_id"] = coordinator_agent_id
    state["parent_agent_id"] = None
    state["root_agent_id"] = root_agent_id
    state["collaboration_depth"] = 0
    state["final_response"] = _build_collaboration_final_response(state.get("task_results") or [])
    _append_event_trace(
        state,
        "agent.verify.started",
        payload={
            "summary_present": bool(state["final_response"]),
            "evidence_count": len(state.get("task_results") or []),
        },
    )
    _verify_collaboration_results(state, tasks, state.get("task_results") or [])
    _append_event_trace(
        state,
        "agent.verify.completed",
        payload={
            "status": state.get("verification_status"),
            "summary": state.get("verification_summary"),
            "evidence_count": len(state.get("verification_evidence") or []),
        },
        severity="error" if state.get("verification_status") == "failed" else "info",
    )
    # Only the combined answer is appended; the per-task sub-commander turns are discarded.
    state["messages"] = [*base_history, AIMessage(content=state["final_response"])]
    state["should_respond"] = True
    return state
|
||
|
||
|
||
def _can_delegate_within_hop_budget(state: AgentState) -> bool:
    """Return True while ``routing_hops`` is still below ``max_routing_hops``."""
    hops_used = _get_state_int(state, "routing_hops")
    hop_limit = _get_state_int(state, "max_routing_hops")
    return hops_used < hop_limit
|
||
|
||
|
||
def _stop_due_to_loop_guard(state: AgentState) -> AgentState:
    """End the turn because the routing-hop budget is spent.

    Flags ``terminated_due_to_loop_guard``, sets a fixed user-facing notice as
    the final response and appends it to the message history.
    """
    notice = "这次需要处理的步骤有点多,我先停在这里。您可以把目标再明确一点,或让我先只完成其中一步。"
    state["terminated_due_to_loop_guard"] = True
    state["final_response"] = notice
    prior_messages = state.get("messages", [])
    state["messages"] = [*prior_messages, AIMessage(content=notice)]
    return state
|
||
|
||
|
||
def _master_reply(state: AgentState, text: str) -> AgentState:
    """Finish the master turn with ``text`` as the final response, appended to history."""
    state["final_response"] = text
    state["messages"] = [*state.get("messages", []), AIMessage(content=text)]
    return state


def _master_delegate(state: AgentState, routed_agent: AgentRole) -> AgentState:
    """Hand the turn to ``routed_agent``, or stop if the routing-hop budget is exhausted.

    Effects on ``state``:
    - increments ``routing_hops``;
    - points ``current_agent`` and ``next_step`` at the target role;
    - registers the role in ``active_agents`` if it is not already there;
    - appends the role to ``agent_trace``.
    """
    if not _can_delegate_within_hop_budget(state):
        return _stop_due_to_loop_guard(state)
    state["routing_hops"] = int(state.get("routing_hops") or 0) + 1
    state["current_agent"] = routed_agent.value
    state["next_step"] = routed_agent.value
    if routed_agent not in state["active_agents"]:
        state["active_agents"] = [*state["active_agents"], routed_agent]
    state["agent_trace"] = [*(state.get("agent_trace") or [AgentRole.MASTER.value]), routed_agent.value]
    return state


async def master_node(state: AgentState) -> AgentState:
    """Entry node of the agent graph: answer directly, delegate, or collaborate.

    Decisions, in order:
      1. canned replies for greetings, identity questions and capability questions;
      2. continuity routes: structured continuity, then a pending clarification,
         then a short confirmation of a previously proposed schedule;
      3. otherwise request-mode selection, which may run the collaboration flow
         and return its result;
      4. routing of the user query via ``_route_agent_from_user_query``;
      5. if still unrouted, an LLM answer. A short (<= 64 chars) reply that maps
         to a specialist role is treated as a delegation, not as the answer.

    Every delegation goes through the routing-hop budget check.
    """
    _maybe_reset_turn_budgets(state)
    user_messages = _filter_user_messages(state["messages"])
    user_query = _stringify_message_content(user_messages[-1].content).strip() if user_messages else ""

    state["current_agent"] = state.get("current_agent") or AgentRole.MASTER
    state["active_agents"] = _normalize_active_agents(state.get("active_agents"))

    if _is_simple_greeting(user_query):
        return _master_reply(state, "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。")

    if _is_identity_question(user_query):
        return _master_reply(
            state,
            (
                "我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:"
                "帮您看清问题、压缩路径、把事情往前推进。"
            ),
        )

    if _is_capability_question(user_query):
        return _master_reply(
            state,
            (
                "主要做三件事。\n"
                "- 帮您判断:看问题本质、梳理取舍、给出方向\n"
                "- 帮您收束:把复杂内容理顺,把重点拎出来\n"
                "- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n"
                "如果您现在有具体目标,我可以直接进入处理。"
            ),
        )

    state["current_agent"] = _normalize_current_agent(state.get("current_agent"))

    # Both continuity lookups are evaluated up front, as before; the first non-None wins.
    structured_continuity_route = _route_from_structured_continuity(state, user_query)
    clarification_route = _route_from_clarification_context(state, user_query)
    if structured_continuity_route is not None:
        state["execution_mode"] = "direct"
        routed_agent = structured_continuity_route
    elif clarification_route is not None:
        state["execution_mode"] = "direct"
        routed_agent = clarification_route
    elif _is_short_confirmation(user_query) and _previous_turn_proposed_schedule_creation(state.get("messages", [])):
        state["execution_mode"] = "direct"
        routed_agent = AgentRole.SCHEDULE_PLANNER
    else:
        request_mode, routing_metadata = _select_request_mode(user_query)
        state["routing_decision"] = routing_metadata
        if request_mode == "collaboration":
            collaboration_state = await _run_collaboration_flow(state, user_query)
            if collaboration_state.get("execution_mode") == "collaboration" or collaboration_state.get("final_response"):
                return collaboration_state
        # Either a direct request, or the collaboration plan fell back to direct mode.
        state["execution_mode"] = "direct"
        routed_agent = _route_agent_from_user_query(user_query)

    if routed_agent != AgentRole.MASTER:
        return _master_delegate(state, routed_agent)

    llm = _get_llm_for_state(state)
    response = await _invoke_llm(llm, [SystemMessage(content=MASTER_SYSTEM_PROMPT), *state["messages"]])
    content = _stringify_message_content(response.content).strip()

    # A short LLM reply that maps to a specialist role is treated as a routing decision.
    routed_agent = _route_agent_from_user_query(content)
    if routed_agent != AgentRole.MASTER and len(content) <= 64:
        return _master_delegate(state, routed_agent)

    return _master_reply(state, content)
|
||
|
||
|
||
async def planner_node(state: AgentState) -> AgentState:
    """Graph node for the schedule-planner role.

    Clears ``next_step`` and takes the latest user message as the query. The
    query is then rewritten in one of two cases:
    - when a clarification is active, it is merged into the resumed
      clarification query;
    - when it is a short confirmation of a schedule proposed in the previous
      turn, it is expanded into the full creation request.

    Finally runs the schedule-planner sub-commander with tools. Its final
    response is stored under ``schedule_context_summary``.
    """
    state["next_step"] = None
    latest_user_messages = _filter_user_messages(state["messages"])
    query = "" if not latest_user_messages else _stringify_message_content(latest_user_messages[-1].content)

    if _has_active_clarification_context(state):
        query = _build_resumed_clarification_query(state, query)
    elif _is_short_confirmation(query) and _previous_turn_proposed_schedule_creation(state.get("messages", [])):
        query = _expand_schedule_confirmation_query(query, state.get("messages", []))

    planner_role = AgentRole.SCHEDULE_PLANNER
    return await _run_sub_commander(
        state,
        planner_role,
        ROLE_SYSTEM_PROMPTS[planner_role],
        query,
        use_tools=True,
        summary_target="schedule_context_summary",
    )
|
||
|
||
|
||
async def executor_node(state: AgentState) -> AgentState:
    """Graph node for the executor role: run its sub-commander on the latest user message."""
    latest_user_messages = _filter_user_messages(state["messages"])
    if latest_user_messages:
        query = _stringify_message_content(latest_user_messages[-1].content)
    else:
        query = ""
    executor_role = AgentRole.EXECUTOR
    return await _run_sub_commander(
        state,
        executor_role,
        ROLE_SYSTEM_PROMPTS[executor_role],
        query,
        use_tools=True,
    )
|
||
|
||
|
||
async def librarian_node(state: AgentState) -> AgentState:
    """Graph node for the librarian role.

    Runs the librarian sub-commander, with tools, on the latest user message.
    Its final response is stored under ``knowledge_context``.
    """
    latest_user_messages = _filter_user_messages(state["messages"])
    if latest_user_messages:
        query = _stringify_message_content(latest_user_messages[-1].content)
    else:
        query = ""
    librarian_role = AgentRole.LIBRARIAN
    return await _run_sub_commander(
        state,
        librarian_role,
        ROLE_SYSTEM_PROMPTS[librarian_role],
        query,
        use_tools=True,
        summary_target="knowledge_context",
    )
|
||
|
||
|
||
async def analyst_node(state: AgentState) -> AgentState:
    """Graph node for the analyst role.

    Runs the analyst sub-commander, with tools, on the latest user message.
    Its final response is stored under ``analysis_report``.
    """
    latest_user_messages = _filter_user_messages(state["messages"])
    if latest_user_messages:
        query = _stringify_message_content(latest_user_messages[-1].content)
    else:
        query = ""
    analyst_role = AgentRole.ANALYST
    return await _run_sub_commander(
        state,
        analyst_role,
        ROLE_SYSTEM_PROMPTS[analyst_role],
        query,
        use_tools=True,
        summary_target="analysis_report",
    )
|
||
|
||
|
||
def route_agent(state: AgentState) -> str:
    """Choose the node that follows the master node.

    The choice is made in this order:
    - ``END`` once a final response exists;
    - otherwise ``next_step``, if it names a known role;
    - otherwise ``current_agent``, defaulting to the master role.

    NOTE(review): ``create_agent_graph`` maps no branch for the master role
    itself, so that last fallback is expected to be unreachable; confirm.
    """
    if state.get("final_response"):
        return END
    requested_step = _role_value(state.get("next_step"))
    if requested_step not in _role_values():
        return _role_value(state.get("current_agent") or AgentRole.MASTER)
    return requested_step
|
||
|
||
|
||
def _compile_graph(graph: StateGraph, callbacks: list | None = None):
|
||
if callbacks:
|
||
try:
|
||
return graph.compile(callbacks=callbacks)
|
||
except TypeError as exc:
|
||
if "callbacks" not in str(exc):
|
||
raise
|
||
return graph.compile()
|
||
|
||
|
||
def create_agent_graph(callbacks: list | None = None):
    """Build and compile the Jarvis agent ``StateGraph``.

    Graph shape:
    - the master node is the entry point;
    - the master node branches through ``route_agent`` to one specialist node
      or to ``END``;
    - every specialist node goes straight to ``END``.

    ``callbacks`` are forwarded to ``_compile_graph``.
    """
    graph = StateGraph(AgentState)

    node_handlers = (
        (AgentRole.MASTER, master_node),
        (AgentRole.SCHEDULE_PLANNER, planner_node),
        (AgentRole.EXECUTOR, executor_node),
        (AgentRole.LIBRARIAN, librarian_node),
        (AgentRole.ANALYST, analyst_node),
    )
    for node_role, handler in node_handlers:
        graph.add_node(node_role.value, handler)

    graph.set_entry_point(AgentRole.MASTER.value)

    specialist_roles = (
        AgentRole.SCHEDULE_PLANNER,
        AgentRole.EXECUTOR,
        AgentRole.LIBRARIAN,
        AgentRole.ANALYST,
    )
    branch_targets = {specialist.value: specialist.value for specialist in specialist_roles}
    branch_targets[END] = END
    graph.add_conditional_edges(AgentRole.MASTER.value, route_agent, branch_targets)

    for specialist in specialist_roles:
        graph.add_edge(specialist.value, END)

    return _compile_graph(graph, callbacks=callbacks)
|
||
|
||
|
||
# Module-wide cache of the compiled graph; populated lazily by get_agent_graph().
_agent_graph = None


def get_agent_graph(callbacks: list | None = None):
    """Return the shared compiled agent graph, building it on first use.

    On the first call, the graph is compiled with ``callbacks`` followed by the
    LangSmith callbacks from ``app.config_tracing``. If the combined list is
    empty, ``None`` is passed instead.

    Later calls return the cached graph and ignore ``callbacks``.
    """
    global _agent_graph
    if _agent_graph is not None:
        return _agent_graph

    # Deferred import: presumably keeps tracing setup out of module import time.
    from app.config_tracing import get_langsmith_callbacks

    merged_callbacks = (callbacks or []) + get_langsmith_callbacks()
    _agent_graph = create_agent_graph(callbacks=merged_callbacks or None)
    return _agent_graph
|
||
|
||
|
||
# Explicitly exported names. Several underscore-prefixed helpers are listed as well,
# presumably so tests can import them directly; confirm before pruning.
__all__ = [
    "_build_collaboration_tasks",
    "_build_verifier_hints",
    "_choose_sub_commander",
    "_parse_json_action",
    "_route_agent_from_user_query",
    "_select_request_mode",
    "_run_collaboration_flow",
    "_run_sub_commander",
    "create_agent_graph",
    "get_agent_graph",
    "master_node",
]
|