Align the L3 graph, agent service, and sync tool shims on one canonical continuity contract so clarification resumes and persisted snapshots behave consistently. Add targeted regressions and hardening notes covering system-message coalescing, async bridge usage, and continuity rehydration.
1366 lines
55 KiB
Python
1366 lines
55 KiB
Python
"""Jarvis agent graph orchestration."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
from typing import Any, Literal, cast
|
||
|
||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
||
from langgraph.graph import END, StateGraph
|
||
|
||
from app.agents.prompts import (
|
||
ANALYST_SYSTEM_PROMPT,
|
||
EXECUTOR_SYSTEM_PROMPT,
|
||
JSON_ACTION_FALLBACK_PROMPT,
|
||
LIBRARIAN_SYSTEM_PROMPT,
|
||
MASTER_SYSTEM_PROMPT,
|
||
SCHEDULE_PLANNER_SYSTEM_PROMPT,
|
||
)
|
||
from app.agents.skill_registry import build_skill_context
|
||
from app.agents.state import AgentRole, AgentState
|
||
from app.agents.tools import SUB_COMMANDER_TOOLSETS
|
||
from app.agents.tools.time_reasoning import normalize_tool_time_arguments
|
||
from app.services.llm_service import (
|
||
create_llm_from_config,
|
||
default_provider_capabilities,
|
||
get_llm,
|
||
resolve_provider_capabilities,
|
||
)
|
||
|
||
logger = logging.getLogger("jarvis.agent")

# One-line briefing per sub-commander. `_build_system_messages` appends the
# matching entry as the final system message once a sub-commander is chosen.
SUB_COMMANDER_PROMPTS = dict(
    schedule_analysis="你负责做日程判断与聚焦分析,只给判断和依据,不直接落库。",
    schedule_planning="你负责把当前目标转成明确安排;当用户要求创建提醒/任务/目标时,直接调用工具执行。",
    executor_tasks="你负责执行任务、提醒、目标、待办相关操作,需要时直接调用工具。",
    executor_forum="你只处理论坛与指令帖相关操作。",
    librarian_retrieval="你负责知识检索与外部搜索,基于证据回答问题。",
    librarian_graph="你负责知识图谱上下文与结构化沉淀。",
    analyst_progress="你负责进度研判,汇总当前推进情况。",
    analyst_insights="你负责趋势、风险和机会判断,必要时调用检索工具。",
)
|
||
|
||
# Skill-registry key passed to `build_skill_context` for each specialist role.
ROLE_SKILL_CONTEXT = dict(
    [
        (AgentRole.SCHEDULE_PLANNER, "schedule_planner"),
        (AgentRole.EXECUTOR, "executor"),
        (AgentRole.LIBRARIAN, "librarian"),
        (AgentRole.ANALYST, "analyst"),
    ]
)

# System-prompt constant for each specialist role; AgentRole.MASTER has no entry.
ROLE_SYSTEM_PROMPTS = dict(
    [
        (AgentRole.SCHEDULE_PLANNER, SCHEDULE_PLANNER_SYSTEM_PROMPT),
        (AgentRole.EXECUTOR, EXECUTOR_SYSTEM_PROMPT),
        (AgentRole.LIBRARIAN, LIBRARIAN_SYSTEM_PROMPT),
        (AgentRole.ANALYST, ANALYST_SYSTEM_PROMPT),
    ]
)
|
||
|
||
# Keyword tables for the substring-based routing heuristics below. Matching is
# plain `keyword in text`, so short entries are broad: "周" alone already
# covers "下周"/"本周".
SCHEDULE_KEYWORDS = (
    "提醒", "日程", "安排", "计划", "排期", "会议", "开会",
    "明天", "今天", "后天", "下周", "本周", "周", "星期",
    "交付", "节点", "deadline",
)
ACCOUNTING_INTENT_KEYWORDS = (
    "记账", "账单", "花了多少钱", "用了多少钱",
    "支出", "消费", "花销", "开销",
)
KNOWLEDGE_KEYWORDS = (
    "知识", "搜索", "检索", "资料", "文档",
    "联网", "上网", "查询", "查一下", "最新",
)
GENERAL_QA_PATTERNS = (
    "介绍一下", "介绍下", "什么是", "是谁",
    "在哪里", "为什么", "怎么理解", "聊聊",
)
ANALYSIS_KEYWORDS = (
    "分析", "报告", "统计", "趋势", "风险", "洞察", "总结",
)
EXECUTION_KEYWORDS = (
    "创建", "更新", "修改", "执行", "发帖", "论坛", "帖子", "完成", "处理",
)
SCHEDULE_ANALYSIS_KEYWORDS = (
    "聚焦", "判断", "分析", "优先级", "取舍", "最近对话", "该做什么",
)
SCHEDULE_PLANNING_KEYWORDS = (
    "安排", "计划", "排期", "提醒", "创建", "新增", "会议", "交付", "节点",
)
# Matched with startswith() against normalized user text.
IDENTITY_PATTERNS = (
    "你是谁", "你是誰",
)
CAPABILITY_PATTERNS = (
    "你能做什么", "你可以做什么", "你会做什么",
)
|
||
# Whole-message replies (after punctuation/whitespace normalization) that count
# as "go ahead and create it".
SHORT_CONFIRMATION_PATTERNS = ("创建", "好的创建", "确认创建", "就创建", "那就创建")

# Phrases an assistant turn uses when offering to create a reminder.
SCHEDULE_CONFIRMATION_HINTS = (
    "创建这条提醒",
    "现在创建这条提醒",
    "现在创建提醒",
    "是否需要我现在创建",
)

# Question markers that, combined with a hint above, mark an assistant turn as
# "asking before creating". Both the ASCII "?" and the full-width question mark
# are required; the full-width one is spelled as an escape so that editors or
# Unicode (NFKC) normalization cannot silently fold it into a duplicate "?".
SCHEDULE_CONFIRMATION_QUESTION_MARKERS = ("是否", "要不要", "吗", "?", "\uff1f")
|
||
|
||
|
||
def _role_value(role: AgentRole | str | None) -> str:
    """Return the plain string form of ``role``.

    ``AgentRole`` members yield their ``.value``. Any other input goes through
    ``str()``, except falsy inputs (``None``, ``""``), which map to ``""``.
    """
    if not isinstance(role, AgentRole):
        return str(role) if role else ""
    return role.value
|
||
|
||
|
||
def _normalize_current_agent(value: AgentRole | str | None) -> str:
    """Return the current agent as a string, defaulting to the master role's value when empty."""
    return _role_value(value) or AgentRole.MASTER.value
|
||
|
||
|
||
def _normalize_active_agents(values: list[AgentRole | str] | None) -> list[AgentRole]:
    """Coerce a mixed list of roles/strings into a de-duplicated ``AgentRole`` list.

    First-seen order is kept and values that are not valid ``AgentRole`` values
    are dropped. An empty/``None`` input, or one with no valid entries, yields
    ``[AgentRole.MASTER]``.
    """
    candidates = values if values else [AgentRole.MASTER]
    unique_roles: list[AgentRole] = []
    for candidate in candidates:
        raw_value = _role_value(candidate)
        try:
            parsed = AgentRole(raw_value)
        except ValueError:
            # Unknown role strings (e.g. from an older snapshot) are ignored.
            continue
        if parsed in unique_roles:
            continue
        unique_roles.append(parsed)
    if unique_roles:
        return unique_roles
    return [AgentRole.MASTER]
|
||
|
||
|
||
def _normalize_user_text(text: str) -> str:
|
||
normalized = (text or "").strip().lower()
|
||
normalized = re.sub(r"[,。!?;:,.!?;:\s]+", "", normalized)
|
||
return normalized
|
||
|
||
|
||
def _stringify_message_content(content: Any) -> str:
|
||
if content is None:
|
||
return ""
|
||
if isinstance(content, str):
|
||
return content
|
||
if isinstance(content, list):
|
||
parts: list[str] = []
|
||
for item in content:
|
||
if isinstance(item, str):
|
||
parts.append(item)
|
||
continue
|
||
if isinstance(item, dict):
|
||
text = item.get("text")
|
||
if text:
|
||
parts.append(str(text))
|
||
continue
|
||
nested = item.get("content")
|
||
if nested:
|
||
parts.append(_stringify_message_content(nested))
|
||
continue
|
||
parts.append(str(item))
|
||
return "".join(parts)
|
||
if isinstance(content, dict):
|
||
text = content.get("text")
|
||
if text:
|
||
return str(text)
|
||
nested = content.get("content")
|
||
if nested:
|
||
return _stringify_message_content(nested)
|
||
return json.dumps(content, ensure_ascii=False)
|
||
return str(content)
|
||
|
||
|
||
def _get_state_int(state: AgentState, key: str) -> int:
|
||
value = state.get(key)
|
||
return value if isinstance(value, int) else 0
|
||
|
||
|
||
def _role_values() -> set[str]:
    """Return the set of string values of every ``AgentRole`` member."""
    values: set[str] = set()
    for member in AgentRole:
        values.add(member.value)
    return values
|
||
|
||
|
||
def _summary_state_key(target: str) -> Literal["schedule_context_summary", "knowledge_context", "analysis_report"]:
|
||
if target not in {"schedule_context_summary", "knowledge_context", "analysis_report"}:
|
||
raise ValueError(f"unsupported summary target: {target}")
|
||
return cast(Literal["schedule_context_summary", "knowledge_context", "analysis_report"], target)
|
||
|
||
|
||
def _get_llm_for_state(state: AgentState):
    """Build the chat model for this turn and record its provider capabilities.

    Uses ``create_llm_from_config`` when ``state["user_llm_config"]`` is set,
    otherwise ``get_llm()``. Capability resolution, including writing the
    ``state["provider_capabilities"]`` snapshot, is delegated to
    ``_resolve_capabilities``. That function previously duplicated this block
    line for line, so delegating keeps both entry points on one code path.

    Returns:
        The LLM object produced by the chosen factory.
    """
    user_llm_config = state.get("user_llm_config")
    llm = create_llm_from_config(user_llm_config) if user_llm_config else get_llm()
    _resolve_capabilities(state, llm)
    return llm
|
||
|
||
|
||
def _resolve_capabilities(state: AgentState, llm) -> Any:
|
||
capabilities = getattr(llm, "_jarvis_provider_capabilities", None)
|
||
if capabilities is None:
|
||
config = state.get("user_llm_config")
|
||
capabilities = resolve_provider_capabilities(config) if config else default_provider_capabilities()
|
||
state["provider_capabilities"] = {
|
||
"provider": capabilities.provider,
|
||
"supports_native_tools": capabilities.supports_native_tools,
|
||
"preferred_tool_strategy": capabilities.preferred_tool_strategy,
|
||
}
|
||
return capabilities
|
||
|
||
|
||
def _filter_user_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||
return [message for message in messages if getattr(message, "type", "") in {"human", "user"}]
|
||
|
||
|
||
def _coalesce_system_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||
system_parts: list[str] = []
|
||
non_system_messages: list[BaseMessage] = []
|
||
|
||
for message in messages:
|
||
if getattr(message, "type", "") == "system":
|
||
text = _stringify_message_content(getattr(message, "content", ""))
|
||
if text:
|
||
system_parts.append(text)
|
||
continue
|
||
non_system_messages.append(message)
|
||
|
||
if not system_parts:
|
||
return non_system_messages
|
||
|
||
return [SystemMessage(content="\n\n".join(system_parts)), *non_system_messages]
|
||
|
||
|
||
def _is_simple_greeting(text: str) -> bool:
    """Return True when the whole message, once normalized, is one of a few bare greetings."""
    normalized = _normalize_user_text(text)
    greetings = {"你好", "您好", "早", "早上好", "在吗", "嗨", "hi", "hello"}
    return normalized in greetings
|
||
|
||
|
||
def _is_identity_question(text: str) -> bool:
    """Return True when the normalized message starts with a "who are you" pattern."""
    return _normalize_user_text(text).startswith(tuple(IDENTITY_PATTERNS))
|
||
|
||
|
||
def _is_capability_question(text: str) -> bool:
    """Return True when the normalized message starts with a "what can you do" pattern."""
    return _normalize_user_text(text).startswith(tuple(CAPABILITY_PATTERNS))
|
||
|
||
|
||
def _is_short_confirmation(text: str) -> bool:
    """Return True when the entire normalized message is one of SHORT_CONFIRMATION_PATTERNS."""
    return _normalize_user_text(text) in SHORT_CONFIRMATION_PATTERNS
|
||
|
||
|
||
def _tool_result_indicates_failure(result: Any) -> bool:
    """Heuristically detect a failed tool run.

    A result counts as failed when its flattened text contains "失败" ("failed").
    NOTE(review): this also matches results that merely mention the word.
    """
    flattened = _stringify_message_content(result)
    return "失败" in flattened
|
||
|
||
|
||
def _latest_assistant_message_content(messages: list[BaseMessage]) -> str:
|
||
previous_assistant_message = next(
|
||
(
|
||
message
|
||
for message in reversed(messages[:-1])
|
||
if getattr(message, "type", "") == "ai"
|
||
),
|
||
None,
|
||
)
|
||
if previous_assistant_message is None:
|
||
return ""
|
||
return _stringify_message_content(getattr(previous_assistant_message, "content", ""))
|
||
|
||
|
||
def _previous_turn_proposed_schedule_creation(messages: list[BaseMessage]) -> bool:
|
||
if len(messages) < 2:
|
||
return False
|
||
content = _latest_assistant_message_content(messages)
|
||
has_schedule_confirmation_hint = any(hint in content for hint in SCHEDULE_CONFIRMATION_HINTS)
|
||
has_question_marker = any(marker in content for marker in SCHEDULE_CONFIRMATION_QUESTION_MARKERS)
|
||
return has_schedule_confirmation_hint and has_question_marker
|
||
|
||
|
||
def _previous_turn_completed_reminder_creation(state: AgentState) -> bool:
|
||
created_entities = state.get("created_entities") or []
|
||
if not created_entities:
|
||
return False
|
||
latest_entity = created_entities[-1]
|
||
if latest_entity.get("type") != "reminder":
|
||
return False
|
||
tool_calls = state.get("tool_calls") or []
|
||
if not tool_calls:
|
||
return False
|
||
latest_tool_call = tool_calls[-1]
|
||
if latest_tool_call.get("name") != "create_reminder":
|
||
return False
|
||
messages = state.get("messages") or []
|
||
if len(messages) < 2:
|
||
return False
|
||
if getattr(messages[-1], "type", "") not in {"human", "user"}:
|
||
return False
|
||
previous_message = messages[-2]
|
||
if getattr(previous_message, "type", "") != "ai":
|
||
return False
|
||
previous_assistant_content = _stringify_message_content(getattr(previous_message, "content", ""))
|
||
completion_markers = ("已创建提醒", "提醒已经创建好了", "帮你设好了这条提醒", "创建成功")
|
||
return any(marker in previous_assistant_content for marker in completion_markers)
|
||
|
||
|
||
def _latest_non_confirmation_user_request(messages: list[BaseMessage]) -> str | None:
|
||
user_messages = [
|
||
_stringify_message_content(getattr(message, "content", "")).strip()
|
||
for message in messages
|
||
if getattr(message, "type", "") in {"human", "user"}
|
||
]
|
||
for content in reversed(user_messages[:-1]):
|
||
if content and not _is_short_confirmation(content):
|
||
return content
|
||
return None
|
||
|
||
|
||
def _expand_schedule_confirmation_query(user_query: str, messages: list[BaseMessage]) -> str:
    """Rewrite a bare "create it" reply into an explicit request that quotes the earlier user ask.

    Falls back to ``user_query`` unchanged when no earlier substantive request exists.
    """
    previous_request = _latest_non_confirmation_user_request(messages)
    return f"用户确认继续创建上一条提醒安排:{previous_request}" if previous_request else user_query
|
||
|
||
|
||
def _is_schedule_creation_confirmation_response(response_text: str) -> bool:
    """Return True when an assistant reply offers to create a reminder and phrases it as a question.

    Both a SCHEDULE_CONFIRMATION_HINTS phrase and a
    SCHEDULE_CONFIRMATION_QUESTION_MARKERS token must appear.
    """
    text = _stringify_message_content(response_text)
    if not any(hint in text for hint in SCHEDULE_CONFIRMATION_HINTS):
        return False
    return any(marker in text for marker in SCHEDULE_CONFIRMATION_QUESTION_MARKERS)
|
||
|
||
|
||
def _write_schedule_creation_continuity(state: AgentState, user_query: str) -> None:
|
||
summary = user_query.strip()
|
||
if not summary:
|
||
return
|
||
state["pending_action"] = {
|
||
"type": "schedule_creation",
|
||
"summary": summary,
|
||
"status": "pending",
|
||
}
|
||
state["routing_decision"] = {
|
||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||
"reason": "continue_pending_action",
|
||
}
|
||
state["continuity_state"] = {"status": "fresh"}
|
||
|
||
|
||
def _clear_structured_continuity(state: AgentState) -> None:
|
||
state["pending_action"] = None
|
||
state["routing_decision"] = None
|
||
state["continuity_state"] = None
|
||
|
||
|
||
def _should_clear_schedule_creation_continuity(state: AgentState, created_entities: list[dict[str, Any]]) -> bool:
    """Return True when an active schedule-creation continuation has now produced a reminder."""
    if not _has_active_structured_continuation(state):
        return False
    if (state.get("pending_action") or {}).get("type") != "schedule_creation":
        return False
    return any(entity.get("type") == "reminder" for entity in created_entities)
|
||
|
||
|
||
def _route_agent_from_user_query(user_query: str) -> AgentRole:
    """Pick the top-level agent for ``user_query`` using keyword heuristics.

    The first matching rule wins:
    accounting keywords -> EXECUTOR; a month/day date or schedule keyword ->
    SCHEDULE_PLANNER; analysis -> ANALYST; knowledge -> LIBRARIAN; general Q&A
    phrasing -> MASTER; execution verbs -> EXECUTOR; otherwise MASTER.
    """
    lowered = (user_query or "").strip().lower()

    def mentions(keywords) -> bool:
        return any(keyword in lowered for keyword in keywords)

    if mentions(ACCOUNTING_INTENT_KEYWORDS):
        return AgentRole.EXECUTOR
    if re.search(r"\d{1,2}月\d{1,2}日", lowered) or mentions(SCHEDULE_KEYWORDS):
        return AgentRole.SCHEDULE_PLANNER
    if mentions(ANALYSIS_KEYWORDS):
        return AgentRole.ANALYST
    if mentions(KNOWLEDGE_KEYWORDS):
        return AgentRole.LIBRARIAN
    if mentions(GENERAL_QA_PATTERNS):
        return AgentRole.MASTER
    if mentions(EXECUTION_KEYWORDS):
        return AgentRole.EXECUTOR
    return AgentRole.MASTER
|
||
|
||
|
||
def _choose_sub_commander(role: AgentRole, user_query: str) -> str:
    """Pick the sub-commander key (see SUB_COMMANDER_PROMPTS) for ``role`` based on ``user_query`` keywords.

    Raises:
        ValueError: ``role`` has no sub-commanders (e.g. the master role).
    """
    lowered = (user_query or "").strip().lower()

    def mentions(keywords) -> bool:
        return any(keyword in lowered for keyword in keywords)

    if role == AgentRole.SCHEDULE_PLANNER:
        wants_planning = bool(re.search(r"\d{1,2}月\d{1,2}日", lowered)) or mentions(SCHEDULE_PLANNING_KEYWORDS)
        return "schedule_planning" if wants_planning else "schedule_analysis"
    if role == AgentRole.EXECUTOR:
        return "executor_forum" if mentions(("论坛", "帖子", "发帖", "指令")) else "executor_tasks"
    if role == AgentRole.LIBRARIAN:
        return "librarian_graph" if mentions(("图谱", "关系", "沉淀", "graph")) else "librarian_retrieval"
    if role == AgentRole.ANALYST:
        return "analyst_insights" if mentions(("趋势", "风险", "洞察", "建议", "机会")) else "analyst_progress"
    raise ValueError(f"unsupported role: {role}")
|
||
|
||
|
||
def _is_missing_knowledge_result(tool_result: str | None) -> bool:
|
||
text = (tool_result or "").strip()
|
||
if not text:
|
||
return True
|
||
markers = (
|
||
"未找到相关知识",
|
||
"知识库可能为空",
|
||
"未找到相关网页结果",
|
||
"暂无相关记录",
|
||
"没有找到",
|
||
|
||
)
|
||
return any(marker in text for marker in markers)
|
||
|
||
|
||
def _extract_json_object(content: str) -> str | None:
|
||
text = (content or "").strip()
|
||
if not text:
|
||
return None
|
||
|
||
fenced_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
|
||
if fenced_match:
|
||
return fenced_match.group(1)
|
||
|
||
decoder = json.JSONDecoder()
|
||
for index, char in enumerate(text):
|
||
if char != "{":
|
||
continue
|
||
try:
|
||
_, end = decoder.raw_decode(text[index:])
|
||
return text[index:index + end]
|
||
except json.JSONDecodeError:
|
||
continue
|
||
return None
|
||
|
||
|
||
def _parse_json_tool_call(item: Any, allowed_tools: list[str]) -> dict[str, Any] | None:
    """Validate one entry of a ``tool_calls`` array; return ``{"name", "args", "reason"}`` or None.

    Arguments are read from ``"arguments"``, falling back to ``"parameters"``.
    They may be an object or a JSON-encoded string of an object (the OpenAI
    function-call convention); anything else rejects the entry.
    """
    if not isinstance(item, dict):
        return None
    name = item.get("name")
    if name not in allowed_tools:
        return None
    args = item.get("arguments")
    if args is None:
        args = item.get("parameters")
    if isinstance(args, str):
        try:
            args = json.loads(args)
        except json.JSONDecodeError:
            return None
    if not isinstance(args, dict):
        return None
    return {"name": name, "args": args, "reason": item.get("reason")}


def _parse_json_action(content: str, allowed_tools: list[str]) -> dict[str, Any] | None:
    """Parse a JSON-mode LLM reply into a structured action.

    Recognized payloads (selected by ``"mode"``):
    - ``tool_call``: ``{"mode", "tool_calls": [{"name", "args", "reason"}, ...]}``.
      The whole action is rejected if any entry is invalid or names a tool
      outside ``allowed_tools``.
    - ``final``: ``{"mode", "final_response"}`` when that field is a string.
    - ``clarification``: ``{"mode", "clarification_question"}`` when that field is a string.

    Returns None for anything else, including unparsable or non-object JSON.
    """
    json_text = _extract_json_object(_stringify_message_content(content))
    if not json_text:
        return None
    try:
        payload = json.loads(json_text)
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict):
        return None

    mode = payload.get("mode")
    if mode == "tool_call":
        raw_tool_calls = payload.get("tool_calls", [])
        if not isinstance(raw_tool_calls, list):
            return None
        tool_calls: list[dict[str, Any]] = []
        for item in raw_tool_calls:
            parsed = _parse_json_tool_call(item, allowed_tools)
            if parsed is None:
                return None
            tool_calls.append(parsed)
        return {"mode": mode, "tool_calls": tool_calls}

    if mode == "final" and isinstance(payload.get("final_response"), str):
        return {"mode": mode, "final_response": payload["final_response"]}

    if mode == "clarification" and isinstance(payload.get("clarification_question"), str):
        return {"mode": mode, "clarification_question": payload["clarification_question"]}

    return None
|
||
|
||
|
||
def _has_active_structured_continuation(state: AgentState) -> bool:
|
||
pending_action = state.get("pending_action") or {}
|
||
routing_decision = state.get("routing_decision") or {}
|
||
continuity_state = state.get("continuity_state") or {}
|
||
|
||
if continuity_state.get("status") != "fresh":
|
||
return False
|
||
if pending_action.get("status") != "pending":
|
||
return False
|
||
if routing_decision.get("reason") != "continue_pending_action":
|
||
return False
|
||
target_agent = routing_decision.get("target_agent")
|
||
return target_agent in _role_values()
|
||
|
||
|
||
def _route_from_structured_continuity(state: AgentState, user_query: str) -> AgentRole | None:
    """Route a bare confirmation back to the agent named in an active continuation.

    Returns None unless ``user_query`` is a short confirmation and the
    structured-continuity contract is active. Also returns None if the stored
    target is not a valid ``AgentRole``.
    """
    if not (_is_short_confirmation(user_query) and _has_active_structured_continuation(state)):
        return None
    stored_target = (state.get("routing_decision") or {}).get("target_agent")
    try:
        return AgentRole(str(stored_target))
    except ValueError:
        return None
|
||
|
||
|
||
def _build_structured_continuity_summary(state: AgentState) -> str | None:
|
||
pending_action = state.get("pending_action")
|
||
routing_decision = state.get("routing_decision")
|
||
|
||
if not pending_action or not routing_decision:
|
||
return None
|
||
if not _has_active_structured_continuation(state):
|
||
return None
|
||
|
||
action_type = str(pending_action.get("type") or "unknown")
|
||
action_summary = str(pending_action.get("summary") or "")
|
||
routing_reason = str(routing_decision.get("reason") or "")
|
||
target_agent = str(routing_decision.get("target_agent") or "")
|
||
|
||
lines = [
|
||
"structured_continuity:",
|
||
f"- pending_action.type: {action_type}",
|
||
]
|
||
if action_summary:
|
||
lines.append(f"- pending_action.summary: {action_summary}")
|
||
if target_agent:
|
||
lines.append(f"- routing_decision.target_agent: {target_agent}")
|
||
if routing_reason:
|
||
lines.append(f"- routing_decision.reason: {routing_reason}")
|
||
lines.append("- instruction: continue_pending_action")
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _build_system_messages(state: AgentState, system_prompt: str, role: AgentRole, sub_commander: str) -> list[BaseMessage]:
    """Assemble the ordered system messages for a sub-commander turn.

    Order: role system prompt, current datetime context, structured-continuity
    summary, clarification summary, role-specific context, skill context,
    hand-off notice, sub-commander briefing. Optional pieces are skipped when
    empty.

    Raises:
        KeyError: ``sub_commander`` has no entry in SUB_COMMANDER_PROMPTS.
    """
    contents: list[Any] = [system_prompt]

    datetime_context = state.get("current_datetime_context")
    if datetime_context:
        contents.append(datetime_context)

    for summary in (_build_structured_continuity_summary(state), _build_clarification_summary(state)):
        if summary:
            contents.append(summary)

    role_context = {
        AgentRole.SCHEDULE_PLANNER: state.get("schedule_context_summary"),
        AgentRole.LIBRARIAN: state.get("knowledge_context"),
        AgentRole.ANALYST: state.get("analysis_report"),
    }.get(role)
    if role_context:
        contents.append(f"角色上下文:\n{role_context}")

    skill_key = ROLE_SKILL_CONTEXT.get(role)
    skill_context = build_skill_context(skill_key) if skill_key else None
    if skill_context:
        contents.append(skill_context)

    contents.append(f"本次应由子指挥官 `{sub_commander}` 接手。")
    contents.append(SUB_COMMANDER_PROMPTS[sub_commander])
    return [SystemMessage(content=text) for text in contents]
|
||
|
||
|
||
def _maybe_reset_turn_budgets(state: AgentState) -> None:
|
||
messages = state.get("messages") or []
|
||
if not messages:
|
||
state["routing_hops"] = 0
|
||
state["terminated_due_to_loop_guard"] = False
|
||
state["iteration_count"] = 0
|
||
state["tool_round_count"] = 0
|
||
state["retry_count"] = 0
|
||
state["stop_reason"] = None
|
||
state["clarification_needed"] = False
|
||
state["clarification_question"] = None
|
||
state["final_response"] = None
|
||
return
|
||
|
||
last_message_type = getattr(messages[-1], "type", "")
|
||
has_prior_assistant_turn = any(getattr(message, "type", "") == "ai" for message in messages[:-1])
|
||
if last_message_type in {"human", "user"} and has_prior_assistant_turn:
|
||
state["routing_hops"] = 0
|
||
state["terminated_due_to_loop_guard"] = False
|
||
state["iteration_count"] = 0
|
||
state["tool_round_count"] = 0
|
||
state["retry_count"] = 0
|
||
state["stop_reason"] = None
|
||
state["clarification_needed"] = False
|
||
state["clarification_question"] = None
|
||
state["final_response"] = None
|
||
|
||
|
||
def _conversation_history_messages(state: AgentState) -> list[BaseMessage]:
|
||
history = list(state.get("messages", []))
|
||
return [message for message in history if getattr(message, "type", "") != "system"]
|
||
|
||
|
||
def _record_sub_commander(state: AgentState, role: AgentRole, sub_commander: str, user_query: str) -> None:
    """Mark ``role``/``sub_commander`` as handling ``user_query`` and append trace entries.

    Updates ``current_agent``, ``current_sub_commander``, the de-duplicated
    ``active_agents`` list and the ``active_sub_commanders`` list. It also
    appends identical but separate entries to ``sub_commander_trace`` and
    ``retrieval_trace``. Lists are rebuilt rather than mutated in place.
    """
    state["current_agent"] = role.value
    state["current_sub_commander"] = sub_commander

    agents = _normalize_active_agents(state.get("active_agents"))
    if role not in agents:
        agents = [*agents, role]
    state["active_agents"] = agents

    state["active_sub_commanders"] = [*(state.get("active_sub_commanders") or []), sub_commander]

    trace_entry = {"agent": _role_value(role), "sub_commander": sub_commander, "query": user_query}
    state["sub_commander_trace"] = [*(state.get("sub_commander_trace") or []), trace_entry]
    # A separate dict, so the two traces never alias each other.
    state["retrieval_trace"] = [*(state.get("retrieval_trace") or []), dict(trace_entry)]
|
||
|
||
|
||
def _stop_sub_commander_due_to_budget(state: AgentState, reason: str) -> None:
|
||
state["stop_reason"] = reason
|
||
state["final_response"] = "这次需要处理的步骤有点多,我先停在这里。您可以把目标再明确一点,或让我先只完成其中一步。"
|
||
|
||
|
||
def _guard_sub_commander_budget(state: AgentState, counter_key: str, max_key: str, reason: str) -> bool:
    """Return True when another step is allowed; otherwise stop the turn and return False.

    A non-positive limit at ``max_key`` means unlimited. When the counter at
    ``counter_key`` has reached the limit, the stop reason and fallback
    response are written to state.
    """
    limit = _get_state_int(state, max_key)
    used = _get_state_int(state, counter_key)
    within_budget = limit <= 0 or used < limit
    if not within_budget:
        _stop_sub_commander_due_to_budget(state, reason)
    return within_budget
|
||
|
||
|
||
def _classify_created_entity(tool_name: str) -> dict[str, str] | None:
|
||
mapping = {
|
||
"create_reminder": "reminder",
|
||
"create_goal": "goal",
|
||
"create_todo": "todo",
|
||
"create_schedule_task": "task",
|
||
"create_task": "task",
|
||
"create_forum_post": "forum_post",
|
||
}
|
||
entity_type = mapping.get(tool_name)
|
||
if not entity_type:
|
||
return None
|
||
return {"type": entity_type, "tool": tool_name}
|
||
|
||
|
||
def _build_reminder_time_clarification(args: dict[str, Any]) -> dict[str, Any] | None:
|
||
if not args.get("date") or args.get("reminder_at"):
|
||
return None
|
||
title = str(args.get("title") or args.get("content") or "这件事").strip() or "这件事"
|
||
return {
|
||
"question": f"要把“{title}”提醒在几点?如果您不想特地指定,我也可以默认按当天早上 9 点给您设置。",
|
||
"missing_fields": ["reminder_at"],
|
||
"partial_args": args,
|
||
}
|
||
|
||
|
||
def _has_active_clarification_context(state: AgentState) -> bool:
|
||
clarification_context = state.get("clarification_context") or {}
|
||
if not clarification_context:
|
||
return False
|
||
if not clarification_context.get("question"):
|
||
return False
|
||
owning_agent = clarification_context.get("owning_agent")
|
||
return isinstance(owning_agent, str) and owning_agent in _role_values()
|
||
|
||
|
||
def _clear_clarification_context(state: AgentState) -> None:
|
||
state["clarification_context"] = None
|
||
|
||
|
||
def _write_clarification_context(
|
||
state: AgentState,
|
||
*,
|
||
role: AgentRole,
|
||
sub_commander: str,
|
||
tool_name: str,
|
||
question: str,
|
||
partial_args: dict[str, Any] | None = None,
|
||
missing_fields: list[str] | None = None,
|
||
) -> None:
|
||
state["clarification_context"] = {
|
||
"owning_agent": role.value,
|
||
"owning_sub_commander": sub_commander,
|
||
"target_action": tool_name,
|
||
"question": question,
|
||
"partial_args": dict(partial_args or {}),
|
||
"missing_fields": list(missing_fields or []),
|
||
"status": "pending",
|
||
}
|
||
pending_action = state.get("pending_action") or {}
|
||
if pending_action.get("status") == "pending":
|
||
state["continuity_state"] = {"status": "fresh", "mode": "resume_after_clarification"}
|
||
|
||
|
||
def _route_from_clarification_context(state: AgentState, user_query: str) -> AgentRole | None:
    """Send the user's reply back to the agent that asked the pending clarification.

    Applies only when a clarification is active and ordinary keyword routing
    would fall back to the master agent. Otherwise, or when the stored owner is
    not a valid ``AgentRole``, returns None.
    """
    if not _has_active_clarification_context(state) or _route_agent_from_user_query(user_query) != AgentRole.MASTER:
        return None
    owner = str((state.get("clarification_context") or {}).get("owning_agent") or "")
    try:
        return AgentRole(owner)
    except ValueError:
        return None
|
||
|
||
|
||
def _build_clarification_summary(state: AgentState) -> str | None:
    """Render the active clarification as a system-prompt snippet, or None when none is active.

    ``missing_fields`` and ``partial_args`` lines are included only when
    non-empty. ``partial_args`` is JSON-encoded with non-ASCII characters kept.
    """
    if not _has_active_clarification_context(state):
        return None
    context = state.get("clarification_context") or {}

    lines = ["clarification_context:"]
    for field in ("owning_agent", "owning_sub_commander", "target_action", "question"):
        lines.append(f"- {field}: {context.get(field)}")

    missing_fields = context.get("missing_fields") or []
    if missing_fields:
        lines.append("- missing_fields: " + ", ".join(str(name) for name in missing_fields))
    partial_args = context.get("partial_args") or {}
    if partial_args:
        lines.append("- partial_args: " + json.dumps(partial_args, ensure_ascii=False))

    lines.append("- instruction: merge the user's latest answer into the missing fields and continue the same action")
    return "\n".join(lines)
|
||
|
||
|
||
def _build_resumed_clarification_query(state: AgentState, user_query: str) -> str:
|
||
clarification_context = state.get("clarification_context") or {}
|
||
partial_args = clarification_context.get("partial_args") or {}
|
||
target_action = clarification_context.get("target_action") or ""
|
||
if not partial_args or target_action != "create_reminder":
|
||
return user_query
|
||
|
||
parts: list[str] = []
|
||
title = partial_args.get("title") or partial_args.get("content")
|
||
if title:
|
||
parts.append(f"title={title}")
|
||
date = partial_args.get("date")
|
||
if date:
|
||
parts.append(f"date={date}")
|
||
parts.append(f"reminder_at={user_query}")
|
||
return f"继续完成提醒创建,请合并参数:{';'.join(parts)}"
|
||
|
||
|
||
def _prepare_tool_calls_for_execution(
    tool_calls: list[dict[str, Any]],
    state: AgentState,
) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
    """Canonicalize and time-normalize tool calls, stopping early if one needs clarification.

    Each call is first mapped to canonical argument names, then its time
    arguments are normalized against ``state["current_datetime_context"]``. The
    first ``create_reminder`` that lacks a time of day aborts preparation.

    Returns:
        ``(prepared_calls, None)`` on success, or
        ``([], {"tool_name", "question", "partial_args", "missing_fields"})``
        when a clarification must be asked first.
    """
    prepared: list[dict[str, Any]] = []
    datetime_context = state.get("current_datetime_context")
    for call in tool_calls:
        tool_name, canonical_args = _canonicalize_tool_call(call["name"], dict(call.get("args") or {}))
        normalized_args = normalize_tool_time_arguments(tool_name, canonical_args, datetime_context)

        clarification = _build_reminder_time_clarification(normalized_args) if tool_name == "create_reminder" else None
        if clarification:
            return [], {
                "tool_name": tool_name,
                "question": clarification["question"],
                "partial_args": clarification.get("partial_args") or normalized_args,
                "missing_fields": clarification.get("missing_fields") or [],
            }

        prepared.append(
            {
                "id": call.get("id"),
                "name": tool_name,
                "args": normalized_args,
                "reason": call.get("reason"),
            }
        )
    return prepared, None
|
||
|
||
|
||
def _canonicalize_tool_call(tool_name: str, args: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||
normalized_name = tool_name
|
||
normalized_args = dict(args)
|
||
|
||
if normalized_name == "create_reminder":
|
||
if normalized_args.get("description") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["content"] = normalized_args["description"]
|
||
if normalized_args.get("reminder_content") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["content"] = normalized_args["reminder_content"]
|
||
if normalized_args.get("reminder_time") and not (
|
||
normalized_args.get("reminder_at")
|
||
or normalized_args.get("datetime")
|
||
or normalized_args.get("at")
|
||
or normalized_args.get("remind_at")
|
||
or normalized_args.get("time")
|
||
):
|
||
normalized_args["time"] = normalized_args["reminder_time"]
|
||
|
||
if normalized_name in {"create_schedule_task", "create_task"}:
|
||
if normalized_args.get("task") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["title"] = normalized_args["task"]
|
||
if normalized_args.get("due_datetime") and not (normalized_args.get("due_date") or normalized_args.get("date")):
|
||
normalized_args["due_date"] = normalized_args["due_datetime"]
|
||
if normalized_args.get("due_time") and not normalized_args.get("due_date"):
|
||
normalized_args["due_date"] = normalized_args["due_time"]
|
||
|
||
if normalized_name == "create_todo":
|
||
if normalized_args.get("task") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["title"] = normalized_args["task"]
|
||
if normalized_args.get("date") and not normalized_args.get("todo_date"):
|
||
normalized_args["todo_date"] = normalized_args["date"]
|
||
if normalized_args.get("due_date") and not normalized_args.get("todo_date"):
|
||
normalized_args["todo_date"] = normalized_args["due_date"]
|
||
if normalized_args.get("due_datetime") and not normalized_args.get("todo_date"):
|
||
normalized_args["todo_date"] = normalized_args["due_datetime"]
|
||
if any(normalized_args.get(key) for key in ("due_datetime", "start_time", "end_time", "due_time")):
|
||
normalized_name = "create_schedule_task"
|
||
if normalized_args.get("due_time") and not normalized_args.get("due_date"):
|
||
normalized_args["due_date"] = normalized_args["due_time"]
|
||
elif normalized_args.get("todo_date") and not normalized_args.get("due_date"):
|
||
normalized_args["due_date"] = normalized_args["todo_date"]
|
||
|
||
if normalized_name == "create_goal":
|
||
if normalized_args.get("task") and not (normalized_args.get("title") or normalized_args.get("content")):
|
||
normalized_args["title"] = normalized_args["task"]
|
||
if normalized_args.get("date") and not normalized_args.get("goal_date"):
|
||
normalized_args["goal_date"] = normalized_args["date"]
|
||
if normalized_args.get("target_date") and not normalized_args.get("goal_date"):
|
||
normalized_args["goal_date"] = normalized_args["target_date"]
|
||
|
||
return normalized_name, normalized_args
|
||
|
||
|
||
async def _invoke_llm(llm, messages: list[BaseMessage], tools: list[Any] | None = None):
    """Send ``messages`` to ``llm`` asynchronously and return its response.

    The message list is first passed through ``_coalesce_system_messages``.
    When a non-empty ``tools`` list is supplied, the model is bound to those
    tools via ``bind_tools`` for this single call; otherwise ``llm`` is used
    as given.
    """
    prepared_messages = _coalesce_system_messages(messages)
    runnable = llm.bind_tools(tools) if tools else llm
    return await runnable.ainvoke(prepared_messages)
|
||
|
||
|
||
async def _execute_tool_calls(
    tool_calls: list[dict[str, Any]],
    toolset: list[Any],
    state: AgentState,
) -> tuple[list[dict[str, Any]], str, list[dict[str, str]], list[ToolMessage]]:
    """Execute each requested tool call in order and gather the outcomes.

    Args:
        tool_calls: Calls to run; each is read through ``call["name"]`` and the
            optional ``"id"``, ``"args"`` and ``"reason"`` keys.
        toolset: Tool objects available to the current sub-commander, looked
            up by their ``name`` attribute.
        state: Current agent state (not read by this function).

    Returns:
        A 4-tuple of:
        the per-call records ``{"id", "name", "args", "reason"}``;
        the newline-joined ``"[tool_name] result"`` lines;
        the created-entity descriptors for calls whose tool is classified as
        creating something and whose result does not indicate failure;
        and one ``ToolMessage`` per call, to feed back to the model.

    Raises:
        ValueError: if a call names a tool that is not in ``toolset``.
    """
    tools_by_name: dict[str, Any] = {}
    for candidate in toolset:
        tools_by_name[candidate.name] = candidate

    call_records: list[dict[str, Any]] = []
    summary_lines: list[str] = []
    entities: list[dict[str, str]] = []
    feedback_messages: list[ToolMessage] = []

    for request in tool_calls:
        name = request["name"]
        arguments = dict(request.get("args") or {})
        selected_tool = tools_by_name.get(name)
        if selected_tool is None:
            raise ValueError(f"Tool not found: {name}")

        # Prefer the async entry point; synchronous tools run in a worker
        # thread so the event loop is not blocked.
        try:
            if hasattr(selected_tool, "ainvoke"):
                outcome = await selected_tool.ainvoke(arguments)
            else:
                outcome = await asyncio.to_thread(selected_tool.invoke, arguments)
        except Exception as exc:
            # Best effort: a failing tool becomes an error string the model can react to.
            logger.exception("Tool execution failed: %s args=%s", name, arguments)
            outcome = f"工具执行失败: {exc}"

        call_records.append(
            {
                "id": request.get("id"),
                "name": name,
                "args": arguments,
                "reason": request.get("reason"),
            }
        )
        summary_lines.append(f"[{name}] {outcome}")
        feedback_messages.append(
            ToolMessage(
                content=_stringify_message_content(outcome),
                tool_call_id=str(request.get("id") or name),
                name=name,
            )
        )

        created = _classify_created_entity(name)
        if created and not _tool_result_indicates_failure(outcome):
            entities.append(created)

    return call_records, "\n".join(summary_lines), entities, feedback_messages
|
||
|
||
|
||
async def _run_sub_commander(
    state: AgentState,
    role: AgentRole,
    manager_prompt: str,
    user_query: str,
    *,
    use_tools: bool,
    summary_target: str | None = None,
):
    """Run one sub-commander turn for ``role`` and write the outcome into ``state``.

    A sub-commander is chosen with ``_choose_sub_commander`` and the prompt is
    built from the role's system messages, the replayed conversation history
    and the (possibly clarification-resumed) ``user_query``. One of three
    strategies then runs:

    * no tools (``use_tools`` is false or the toolset is empty): one LLM call;
    * native tool calling, when the resolved provider capabilities report
      ``supports_native_tools``;
    * a JSON-action fallback prompt otherwise, with bounded re-prompts when
      the model's output cannot be parsed.

    Every loop is bounded by ``_guard_sub_commander_budget`` on the iteration,
    tool-round and retry counters. When a tool call needs more information,
    the turn stops with ``clarification_needed`` set and a clarification
    context recorded for ``role``.

    Args:
        state: Agent state; mutated in place.
        role: The agent role running this turn.
        manager_prompt: Role system prompt passed to ``_build_system_messages``.
        user_query: Latest user request text.
        use_tools: Whether the sub-commander's toolset may be used.
        summary_target: When set, ``final_response`` is also stored under
            ``_summary_state_key(summary_target)``.

    Returns:
        The same ``state`` object. A string ``final_response`` is appended to
        ``messages`` as an ``AIMessage``, and ``should_respond`` is set to True.
    """
    # Each sub-commander turn starts with no pending clarification or stop reason.
    state["clarification_needed"] = False
    state["clarification_question"] = None
    state["stop_reason"] = None

    llm = _get_llm_for_state(state)
    capabilities = _resolve_capabilities(state, llm)
    # Only the agent recorded as owning the pending clarification folds the
    # user's answer back into the original request.
    if _has_active_clarification_context(state) and role.value == str((state.get("clarification_context") or {}).get("owning_agent") or ""):
        user_query = _build_resumed_clarification_query(state, user_query)
    sub_commander = _choose_sub_commander(role, user_query)
    _record_sub_commander(state, role, sub_commander, user_query)

    toolset = SUB_COMMANDER_TOOLSETS.get(sub_commander, []) if use_tools else []
    # A bare confirmation right after a reminder was created is answered
    # directly instead of being sent to the executor's tools again.
    if (
        role == AgentRole.EXECUTOR
        and _is_short_confirmation(user_query)
        and _previous_turn_completed_reminder_creation(state)
    ):
        state["tool_calls"] = []
        state["last_tool_result"] = None
        state["final_response"] = "上一条提醒已经创建好了。若您现在要新建别的内容,请直接告诉我要创建什么。"
        history_messages = list(state.get("messages", []))
        history_messages.append(AIMessage(content=state["final_response"]))
        state["messages"] = history_messages
        state["should_respond"] = True
        return state
    base_messages = _build_system_messages(state, manager_prompt, role, sub_commander)
    conversation_history = _conversation_history_messages(state)
    # The request is re-sent below as an explicit "用户请求" message, so a
    # trailing human turn is dropped from the replayed history (presumably to
    # avoid sending it twice).
    if conversation_history and getattr(conversation_history[-1], "type", "") in {"human", "user"}:
        conversation_history = conversation_history[:-1]
    user_message = HumanMessage(content=f"用户请求: {user_query}")
    working_messages = [*base_messages, *conversation_history, user_message]

    state["tool_calls"] = []
    state["last_tool_result"] = None
    state["tool_strategy_used"] = None
    state["fallback_parse_error"] = None

    # Exhausted turn budgets short-circuit the whole turn; the guard itself
    # records whatever it records on ``state`` when it refuses.
    if not _guard_sub_commander_budget(state, "tool_round_count", "max_tool_rounds", "max_tool_rounds_exceeded"):
        pass
    elif not _guard_sub_commander_budget(state, "retry_count", "max_retries", "max_retries_exceeded"):
        pass
    elif not toolset:
        # No tools: a single plain completion.
        if _guard_sub_commander_budget(state, "iteration_count", "max_iterations", "max_iterations_exceeded"):
            state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
            response = await _invoke_llm(llm, working_messages)
            state["final_response"] = _stringify_message_content(response.content)
    elif capabilities.supports_native_tools:
        state["tool_strategy_used"] = "native"
        bound_llm = llm.bind_tools(toolset)
        while state.get("final_response") is None and not state.get("clarification_needed"):
            if not _guard_sub_commander_budget(state, "iteration_count", "max_iterations", "max_iterations_exceeded"):
                break
            state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
            response = await _invoke_llm(bound_llm, working_messages)
            tool_calls = getattr(response, "tool_calls", None) or []
            if tool_calls:
                if not _guard_sub_commander_budget(state, "tool_round_count", "max_tool_rounds", "max_tool_rounds_exceeded"):
                    break
                prepared_calls, clarification = _prepare_tool_calls_for_execution(tool_calls, state)
                if clarification:
                    # Missing required arguments: ask the user and remember
                    # what was already known so the next turn can resume.
                    state["clarification_needed"] = True
                    state["clarification_question"] = clarification["question"]
                    _write_clarification_context(
                        state,
                        role=role,
                        sub_commander=sub_commander,
                        tool_name=clarification["tool_name"],
                        question=clarification["question"],
                        partial_args=clarification.get("partial_args"),
                        missing_fields=clarification.get("missing_fields"),
                    )
                    state["stop_reason"] = "clarification_needed"
                    state["final_response"] = clarification["question"]
                    break
                state["tool_round_count"] = int(state.get("tool_round_count") or 0) + 1
                # NOTE(review): the assistant message echoes the model's raw
                # ``tool_calls`` while the ToolMessages answer ``prepared_calls``;
                # this assumes preparation keeps the call ids — confirm.
                assistant_tool_message = AIMessage(
                    content=_stringify_message_content(getattr(response, "content", "")),
                    tool_calls=tool_calls,
                )
                normalized_calls, tool_result, created_entities, tool_messages = await _execute_tool_calls(
                    prepared_calls,
                    toolset,
                    state,
                )
                state["tool_calls"] = normalized_calls
                state["last_tool_result"] = tool_result
                state["created_entities"] = [*(state.get("created_entities") or []), *created_entities]
                # A successful creation resolves any pending clarification.
                if created_entities:
                    _clear_clarification_context(state)
                if role == AgentRole.SCHEDULE_PLANNER and _should_clear_schedule_creation_continuity(state, created_entities):
                    _clear_structured_continuity(state)
                working_messages = [*working_messages, assistant_tool_message, *tool_messages]
                if sub_commander == "librarian_retrieval" and _is_missing_knowledge_result(tool_result):
                    working_messages.append(SystemMessage(content="如果检索工具没有找到证据,可以直接基于你的常识给出清晰回答,不要机械地说不知道。"))
                continue
            # No tool calls: the model's text is the answer.
            state["final_response"] = _stringify_message_content(response.content)
    else:
        state["tool_strategy_used"] = "json_fallback"
        allowed_tools = [tool.name for tool in toolset]
        while state.get("final_response") is None and not state.get("clarification_needed"):
            if not _guard_sub_commander_budget(state, "iteration_count", "max_iterations", "max_iterations_exceeded"):
                break
            state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
            parsed = None
            retry_instruction: BaseMessage | None = None
            # Re-prompt until the model emits a parsable JSON action or a
            # budget/retry limit stops the turn.
            while parsed is None:
                response = await _invoke_llm(
                    llm,
                    [
                        *working_messages,
                        SystemMessage(content=JSON_ACTION_FALLBACK_PROMPT),
                        SystemMessage(content=f"本次可用工具列表: {', '.join(allowed_tools)}"),
                        *([retry_instruction] if retry_instruction else []),
                    ],
                )
                response_text = _stringify_message_content(response.content)
                parsed = _parse_json_action(response_text, allowed_tools)
                # After at least one tool round, plain prose is accepted as the final answer.
                if parsed is None and response_text.strip() and state.get("tool_round_count"):
                    state["fallback_parse_error"] = None
                    state["final_response"] = response_text.strip()
                    break
                if parsed is not None:
                    state["fallback_parse_error"] = None
                    break
                if not _guard_sub_commander_budget(state, "iteration_count", "max_iterations", "max_iterations_exceeded"):
                    parsed = None
                    break
                if int(state.get("retry_count") or 0) >= int(state.get("max_retries") or 0):
                    state["fallback_parse_error"] = "invalid_json_action"
                    state["final_response"] = "这次内部动作解析没整理好,不过您的意思我接住了。您再说一遍要我执行的内容,我只回结果,不展示内部调用细节。"
                    break
                state["iteration_count"] = int(state.get("iteration_count") or 0) + 1
                state["retry_count"] = int(state.get("retry_count") or 0) + 1
                retry_instruction = SystemMessage(content="上一次输出不是有效 JSON。请严格只返回合法 JSON,不要加解释。")
            if state.get("final_response") is not None:
                break
            if parsed is None:
                break
            if parsed["mode"] == "final":
                state["final_response"] = parsed["final_response"]
                break
            if parsed["mode"] == "clarification":
                state["clarification_needed"] = True
                state["clarification_question"] = parsed["clarification_question"]
                _write_clarification_context(
                    state,
                    role=role,
                    sub_commander=sub_commander,
                    tool_name="clarification",
                    question=parsed["clarification_question"],
                )
                state["stop_reason"] = "clarification_needed"
                state["final_response"] = parsed["clarification_question"]
                break
            if not _guard_sub_commander_budget(state, "tool_round_count", "max_tool_rounds", "max_tool_rounds_exceeded"):
                break
            prepared_calls, clarification = _prepare_tool_calls_for_execution(parsed["tool_calls"], state)
            if clarification:
                state["clarification_needed"] = True
                state["clarification_question"] = clarification["question"]
                _write_clarification_context(
                    state,
                    role=role,
                    sub_commander=sub_commander,
                    tool_name=clarification["tool_name"],
                    question=clarification["question"],
                    partial_args=clarification.get("partial_args"),
                    missing_fields=clarification.get("missing_fields"),
                )
                state["stop_reason"] = "clarification_needed"
                state["final_response"] = clarification["question"]
                break
            state["tool_round_count"] = int(state.get("tool_round_count") or 0) + 1
            normalized_calls, tool_result, created_entities, tool_messages = await _execute_tool_calls(
                prepared_calls,
                toolset,
                state,
            )
            state["tool_calls"] = normalized_calls
            state["last_tool_result"] = tool_result
            state["created_entities"] = [*(state.get("created_entities") or []), *created_entities]
            # Mirror the native-tool path: a successful creation resolves any
            # pending clarification, so its context is dropped right away.
            if created_entities:
                _clear_clarification_context(state)
            if role == AgentRole.SCHEDULE_PLANNER and _should_clear_schedule_creation_continuity(state, created_entities):
                _clear_structured_continuity(state)
            # NOTE(review): unlike the native path, no assistant message with
            # matching ``tool_calls`` precedes these ToolMessages; some
            # providers reject orphan tool messages — confirm for the
            # providers that use this fallback.
            working_messages = [*working_messages, *tool_messages]
            if sub_commander == "librarian_retrieval" and _is_missing_knowledge_result(tool_result):
                working_messages.append(SystemMessage(content="如果检索工具没有找到证据,可以直接基于你的常识给出清晰回答,不要机械地说不知道。"))

    if summary_target:
        state[_summary_state_key(summary_target)] = state.get("final_response")

    final_response_text = state.get("final_response")
    # A completed (non-clarifying) answer closes any pending clarification.
    if not state.get("clarification_needed") and final_response_text:
        _clear_clarification_context(state)
    if (
        role == AgentRole.SCHEDULE_PLANNER
        and isinstance(final_response_text, str)
        and _is_schedule_creation_confirmation_response(final_response_text)
    ):
        _write_schedule_creation_continuity(state, user_query)

    history_messages = list(state.get("messages", []))
    final_response = state.get("final_response")
    if isinstance(final_response, str):
        history_messages.append(AIMessage(content=final_response))

    state["messages"] = history_messages
    state["should_respond"] = True
    return state
|
||
|
||
|
||
def _can_delegate_within_hop_budget(state: AgentState) -> bool:
    """Return True while one more master-to-agent hand-off fits the hop budget.

    Compares ``routing_hops`` against ``max_routing_hops``, both read through
    ``_get_state_int``.
    """
    hops_used = _get_state_int(state, "routing_hops")
    hop_limit = _get_state_int(state, "max_routing_hops")
    return hops_used < hop_limit
|
||
|
||
|
||
def _stop_due_to_loop_guard(state: AgentState) -> AgentState:
    """End the turn because the routing-hop budget is exhausted.

    Flags ``terminated_due_to_loop_guard``, sets a fixed user-facing reply as
    ``final_response`` and appends that reply to ``messages`` as an
    ``AIMessage``. Returns the same ``state`` object.
    """
    stop_reply = "这次需要处理的步骤有点多,我先停在这里。您可以把目标再明确一点,或让我先只完成其中一步。"
    state["terminated_due_to_loop_guard"] = True
    state["final_response"] = stop_reply
    state["messages"] = [*state.get("messages", []), AIMessage(content=stop_reply)]
    return state
|
||
|
||
|
||
def _finish_master_turn_with_reply(state: AgentState, reply: str) -> AgentState:
    """Record ``reply`` as the master's own answer and append it to ``messages``."""
    state["final_response"] = reply
    state["messages"] = [*state.get("messages", []), AIMessage(content=state["final_response"])]
    return state


def _delegate_master_turn_to(state: AgentState, target: AgentRole) -> AgentState:
    """Hand the turn to ``target``, or stop if the routing-hop budget is used up.

    On delegation this increments ``routing_hops``, points ``current_agent`` and
    ``next_step`` at the target, adds it to ``active_agents`` once, and extends
    ``agent_trace`` (seeded with the master when empty).
    """
    if not _can_delegate_within_hop_budget(state):
        return _stop_due_to_loop_guard(state)
    state["routing_hops"] = int(state.get("routing_hops") or 0) + 1
    state["current_agent"] = target.value
    state["next_step"] = target.value
    if target not in state["active_agents"]:
        state["active_agents"] = [*state["active_agents"], target]
    state["agent_trace"] = [*(state.get("agent_trace") or [AgentRole.MASTER.value]), target.value]
    return state


async def master_node(state: AgentState) -> AgentState:
    """Entry node: answer trivial turns directly or delegate to a specialist.

    Order of decisions:
    1. canned replies for greetings, identity and capability questions;
    2. routing from structured continuity, then from a pending clarification,
       then a short confirmation of a proposed schedule creation, then
       keyword routing of the user query;
    3. otherwise ask the master LLM; a short reply that itself routes to a
       specialist is treated as a delegation, anything else is the answer.
    """
    _maybe_reset_turn_budgets(state)
    human_turns = _filter_user_messages(state["messages"])
    query = _stringify_message_content(human_turns[-1].content).strip() if human_turns else ""

    state["current_agent"] = _normalize_current_agent(state.get("current_agent"))
    state["active_agents"] = _normalize_active_agents(state.get("active_agents"))

    if _is_simple_greeting(query):
        return _finish_master_turn_with_reply(state, "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。")

    if _is_identity_question(query):
        return _finish_master_turn_with_reply(
            state,
            "我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:"
            "帮您看清问题、压缩路径、把事情往前推进。",
        )

    if _is_capability_question(query):
        return _finish_master_turn_with_reply(
            state,
            "主要做三件事。\n"
            "- 帮您判断:看问题本质、梳理取舍、给出方向\n"
            "- 帮您收束:把复杂内容理顺,把重点拎出来\n"
            "- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n"
            "如果您现在有具体目标,我可以直接进入处理。",
        )

    # Both continuity lookups are evaluated every turn, in this order.
    continuity_target = _route_from_structured_continuity(state, query)
    clarification_target = _route_from_clarification_context(state, query)
    if continuity_target is not None:
        target = continuity_target
    elif clarification_target is not None:
        target = clarification_target
    elif _is_short_confirmation(query) and _previous_turn_proposed_schedule_creation(state.get("messages", [])):
        target = AgentRole.SCHEDULE_PLANNER
    else:
        target = _route_agent_from_user_query(query)
    if target != AgentRole.MASTER:
        return _delegate_master_turn_to(state, target)

    llm = _get_llm_for_state(state)
    reply = await _invoke_llm(llm, [SystemMessage(content=MASTER_SYSTEM_PROMPT), *state["messages"]])
    content = _stringify_message_content(reply.content).strip()

    # A short LLM reply that names a specialist is read as a routing decision.
    target = _route_agent_from_user_query(content)
    if target != AgentRole.MASTER and len(content) <= 64:
        return _delegate_master_turn_to(state, target)

    return _finish_master_turn_with_reply(state, content)
|
||
|
||
|
||
async def planner_node(state: AgentState) -> AgentState:
    """Schedule-planner node: run the planner sub-commander on the latest request.

    Before delegating, the request is expanded either from an active
    clarification context or, failing that, from a short confirmation of a
    schedule creation proposed on the previous turn. The result is also
    stored as ``schedule_context_summary``.
    """
    state["next_step"] = None
    human_turns = _filter_user_messages(state["messages"])
    query = "" if not human_turns else _stringify_message_content(human_turns[-1].content)

    if _has_active_clarification_context(state):
        # NOTE(review): _run_sub_commander resumes the clarification again when
        # its owning_agent is this role, so an owned clarification may be
        # resumed twice unless _build_resumed_clarification_query is idempotent;
        # this branch also resumes contexts owned by other agents. TODO confirm
        # and keep a single resume site.
        query = _build_resumed_clarification_query(state, query)
    elif _is_short_confirmation(query) and _previous_turn_proposed_schedule_creation(state.get("messages", [])):
        query = _expand_schedule_confirmation_query(query, state.get("messages", []))

    return await _run_sub_commander(
        state,
        AgentRole.SCHEDULE_PLANNER,
        ROLE_SYSTEM_PROMPTS[AgentRole.SCHEDULE_PLANNER],
        query,
        use_tools=True,
        summary_target="schedule_context_summary",
    )
|
||
|
||
|
||
async def executor_node(state: AgentState) -> AgentState:
    """Executor node: run the executor sub-commander, with tools, on the latest user message."""
    human_turns = _filter_user_messages(state["messages"])
    latest_text = ""
    if human_turns:
        latest_text = _stringify_message_content(human_turns[-1].content)
    return await _run_sub_commander(
        state,
        AgentRole.EXECUTOR,
        ROLE_SYSTEM_PROMPTS[AgentRole.EXECUTOR],
        latest_text,
        use_tools=True,
    )
|
||
|
||
|
||
async def librarian_node(state: AgentState) -> AgentState:
    """Librarian node: answer the latest user message with retrieval tools.

    The sub-commander's final response is also stored as ``knowledge_context``.
    """
    human_turns = _filter_user_messages(state["messages"])
    latest_text = ""
    if human_turns:
        latest_text = _stringify_message_content(human_turns[-1].content)
    return await _run_sub_commander(
        state,
        AgentRole.LIBRARIAN,
        ROLE_SYSTEM_PROMPTS[AgentRole.LIBRARIAN],
        latest_text,
        use_tools=True,
        summary_target="knowledge_context",
    )
|
||
|
||
|
||
async def analyst_node(state: AgentState) -> AgentState:
    """Analyst node: run the analyst sub-commander on the latest user message.

    The sub-commander's final response is also stored as ``analysis_report``.
    """
    human_turns = _filter_user_messages(state["messages"])
    latest_text = ""
    if human_turns:
        latest_text = _stringify_message_content(human_turns[-1].content)
    return await _run_sub_commander(
        state,
        AgentRole.ANALYST,
        ROLE_SYSTEM_PROMPTS[AgentRole.ANALYST],
        latest_text,
        use_tools=True,
        summary_target="analysis_report",
    )
|
||
|
||
|
||
def route_agent(state: AgentState) -> str:
    """Choose the node that runs after the master node.

    Returns ``END`` once the turn has a final response. Otherwise it returns
    ``next_step`` when that names a known role, falling back to
    ``current_agent`` (or the master role).
    """
    # ``None`` is the "not answered yet" sentinel used throughout this module
    # (the sub-commander loops test ``final_response is None``). An empty
    # string is still an answer, e.g. a blank master LLM reply. Routing it
    # onward would fall through to ``current_agent``; that is "master", which
    # is not a target in the master's conditional-edge map, or a stale agent.
    if state.get("final_response") is not None:
        return END
    next_step = _role_value(state.get("next_step"))
    if next_step in _role_values():
        return next_step
    return _role_value(state.get("current_agent") or AgentRole.MASTER)
|
||
|
||
|
||
def _compile_graph(graph: StateGraph, callbacks: list | None = None):
    """Compile ``graph``, passing ``callbacks`` to ``compile()`` when supported.

    If ``compile()`` rejects the ``callbacks`` keyword (a ``TypeError``
    mentioning "callbacks"), the graph is compiled without them and a warning
    is logged, so dropped tracing callbacks are visible instead of silently
    lost. Any other ``TypeError`` is re-raised.

    Args:
        graph: The graph builder to compile.
        callbacks: Optional callback handlers; empty or ``None`` skips them.

    Returns:
        Whatever ``graph.compile()`` returns.
    """
    if not callbacks:
        return graph.compile()
    try:
        return graph.compile(callbacks=callbacks)
    except TypeError as exc:
        if "callbacks" not in str(exc):
            raise
        logger.warning(
            "Graph compile() rejected callbacks (%s); compiling without %d callback(s)",
            exc,
            len(callbacks),
        )
    return graph.compile()
|
||
|
||
|
||
def create_agent_graph(callbacks: list | None = None):
    """Build and compile the Jarvis multi-agent graph.

    The master node is the entry point. After it, ``route_agent`` selects one
    of the four specialist nodes or ``END``, and every specialist node goes
    straight to ``END``.

    Args:
        callbacks: Optional callback handlers forwarded to ``_compile_graph``.
    """
    node_handlers = {
        AgentRole.MASTER: master_node,
        AgentRole.SCHEDULE_PLANNER: planner_node,
        AgentRole.EXECUTOR: executor_node,
        AgentRole.LIBRARIAN: librarian_node,
        AgentRole.ANALYST: analyst_node,
    }
    specialists = (
        AgentRole.SCHEDULE_PLANNER,
        AgentRole.EXECUTOR,
        AgentRole.LIBRARIAN,
        AgentRole.ANALYST,
    )

    builder = StateGraph(AgentState)
    for agent_role, handler in node_handlers.items():
        builder.add_node(agent_role.value, handler)

    builder.set_entry_point(AgentRole.MASTER.value)

    branch_targets = {agent_role.value: agent_role.value for agent_role in specialists}
    branch_targets[END] = END
    builder.add_conditional_edges(AgentRole.MASTER.value, route_agent, branch_targets)

    for agent_role in specialists:
        builder.add_edge(agent_role.value, END)

    return _compile_graph(builder, callbacks=callbacks)
|
||
|
||
|
||
# Lazily built, process-wide compiled graph (see get_agent_graph).
_agent_graph = None


def get_agent_graph(callbacks: list | None = None):
    """Return the cached compiled agent graph, building it on first use.

    The first call merges ``callbacks`` with the LangSmith callbacks from
    ``app.config_tracing`` and compiles the graph. Later calls return the
    cached graph and ignore their ``callbacks`` argument.
    """
    global _agent_graph
    if _agent_graph is not None:
        return _agent_graph

    from app.config_tracing import get_langsmith_callbacks

    tracing_callbacks = get_langsmith_callbacks()
    combined = (callbacks or []) + tracing_callbacks
    _agent_graph = create_agent_graph(callbacks=combined or None)
    return _agent_graph
|
||
|
||
|
||
# Names exported by ``from <this module> import *``. Several underscore-prefixed
# helpers are listed explicitly so they are exported as well (presumably for
# tests or sibling modules — confirm against callers before pruning).
__all__ = [
    "_choose_sub_commander",
    "_parse_json_action",
    "_route_agent_from_user_query",
    "_run_sub_commander",
    "create_agent_graph",
    "get_agent_graph",
    "master_node",
]
|