Files
JARVIS/backend/app/agents/graph.py
2026-03-21 11:29:57 +08:00

287 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Jarvis LangGraph Agent 主图定义
"""
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
from app.agents.state import AgentState, AgentRole
from app.agents.prompts import (
MASTER_SYSTEM_PROMPT,
PLANNER_SYSTEM_PROMPT,
EXECUTOR_SYSTEM_PROMPT,
LIBRARIAN_SYSTEM_PROMPT,
ANALYST_SYSTEM_PROMPT,
)
from app.agents.tools import ALL_TOOLS
from app.agents.skill_registry import build_skill_context
from app.services.llm_service import get_llm
def _msg_type(msg: BaseMessage) -> str:
    """Return the type tag of a chat message.

    Prefers the ``type`` attribute (newer naming). If that is missing or
    falsy, falls back to the ``role`` attribute (older naming), and to
    ``"human"`` when ``role`` does not exist either.
    """
    kind = getattr(msg, "type", None)
    if kind:
        return kind
    return getattr(msg, "role", "human")
def _filter_user_messages(messages: list) -> list[BaseMessage]:
    """Return the user-authored messages (type ``human`` or ``user``), in original order."""
    user_kinds = ("human", "user")
    return list(filter(lambda m: _msg_type(m) in user_kinds, messages))
# ===================== 节点定义 (async) =====================
async def master_node(state: AgentState) -> AgentState:
    """Master agent node: interpret the user's intent and pick a sub-agent.

    The LLM reply (to the system prompt, optional memory context and the
    conversation) is scanned for routing keywords:

    * search / look up / knowledge / retrieve  -> librarian
    * plan / arrange / break down / planning   -> planner
    * execute / operate / create / update      -> executor
    * analyse / report / statistics / summary  -> analyst

    If no keyword matches, the reply itself becomes the final answer.

    Args:
        state: Current graph state; reads ``messages`` and, if present,
            ``memory_context``.

    Returns:
        The same state object. On delegation, ``current_agent`` and
        ``active_agents`` are updated and ``final_response`` is cleared. On a
        direct answer, ``final_response`` holds the reply. ``should_respond``
        is set to True in both cases.
    """
    llm = get_llm()
    messages: list[BaseMessage] = state["messages"]

    system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]
    # Inject memory context (if any) as an extra system message.
    memory_ctx = state.get("memory_context")
    if memory_ctx:
        system_msgs.append(
            SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
        )

    # ainvoke is the awaitable API; Runnable.invoke is synchronous and its
    # result cannot be awaited.
    response: AIMessage = await llm.ainvoke(system_msgs + messages)
    content = response.content.strip().lower()

    # Keyword lists must never contain "": the empty string is a substring of
    # every reply and would make every later branch unreachable.
    # NOTE(review): one executor keyword appears to have been lost here
    # (it showed up as ""); restore the intended word if known.
    if any(kw in content for kw in ["搜索", "查找", "知识", "检索"]):
        next_agent = AgentRole.LIBRARIAN
    elif any(kw in content for kw in ["计划", "安排", "拆解", "规划"]):
        next_agent = AgentRole.PLANNER
    elif any(kw in content for kw in ["执行", "操作", "创建", "更新"]):
        next_agent = AgentRole.EXECUTOR
    elif any(kw in content for kw in ["分析", "报告", "统计", "总结"]):
        next_agent = AgentRole.ANALYST
    else:
        # No delegation: answer directly.
        state["final_response"] = response.content
        state["should_respond"] = True
        return state

    state["current_agent"] = next_agent
    state["active_agents"] = state.get("active_agents", [AgentRole.MASTER]) + [next_agent]
    # Clear any stale answer carried in the incoming state; otherwise the
    # router sees a truthy final_response and ends before the sub-agent runs.
    state["final_response"] = ""
    state["should_respond"] = True
    return state
async def planner_node(state: AgentState) -> AgentState:
    """Planner agent node: draft a plan and split it into steps.

    Sends the latest user message, together with the planner system prompt
    and any planner skill context, to the LLM. It then extracts numbered
    ("1. ...") and bulleted ("- ...") lines as plan steps.

    Args:
        state: Current graph state; ``messages`` supplies the user request.

    Returns:
        The same state object with ``plan`` (raw LLM text), ``plan_steps``
        (list of ``{"step": n, "description": text}``, numbered 1..k),
        ``final_response`` (the plan text) and ``should_respond`` set.
    """
    llm = get_llm()
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""

    system_msgs = [SystemMessage(content=PLANNER_SYSTEM_PROMPT)]
    skill_ctx = build_skill_context("planner")
    if skill_ctx:
        system_msgs.append(SystemMessage(content=skill_ctx))

    # ainvoke is the awaitable API; invoke is synchronous.
    response = await llm.ainvoke(
        system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
    )
    plan_text = response.content

    # Strip before testing, so indented items ("  1. ...") are recognised and
    # only a leading "- " counts as a bullet (not a dash mid-sentence).
    step_lines = [
        stripped
        for stripped in (line.strip() for line in plan_text.splitlines())
        if stripped and (stripped[0].isdigit() or stripped.startswith("- "))
    ]
    # Number steps by position in the plan, not by raw line index, which
    # skipped numbers whenever blank or prose lines appeared between steps.
    steps = [
        {"step": number, "description": text}
        for number, text in enumerate(step_lines, start=1)
    ]

    state["plan"] = plan_text
    state["plan_steps"] = steps
    state["final_response"] = plan_text
    state["should_respond"] = True
    return state
async def executor_node(state: AgentState) -> AgentState:
    """Executor agent node: fulfil the user's request by calling tools.

    Binds ``ALL_TOOLS`` to the LLM and runs every tool call the model
    requests. If any tools ran, the LLM is asked again to turn their output
    into a reply; otherwise the model's first reply is used as-is.

    Args:
        state: Current graph state; ``messages`` supplies the user request.

    Returns:
        The same state object with ``final_response`` and ``should_respond``
        set. When tools were called, ``tool_calls`` and ``last_tool_result``
        are set as well.
    """
    llm = get_llm()
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""

    system_msgs = [SystemMessage(content=EXECUTOR_SYSTEM_PROMPT)]
    skill_ctx = build_skill_context("executor")
    if skill_ctx:
        system_msgs.append(SystemMessage(content=skill_ctx))

    # ainvoke is the awaitable API; invoke on a Runnable is synchronous.
    response = await llm.bind_tools(ALL_TOOLS).ainvoke(
        system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
    )

    tool_calls = getattr(response, "tool_calls", None) or []
    if tool_calls:
        results = []
        for tc in tool_calls:
            tool_name = tc.get("name")
            tool = next((t for t in ALL_TOOLS if t.name == tool_name), None)
            if tool is None:
                # Report unknown tools instead of silently dropping the call.
                results.append(f"[{tool_name}] 未找到该工具")
                continue
            try:
                # ainvoke keeps the event loop free while the tool runs.
                result = await tool.ainvoke(tc.get("args", {}))
            except Exception as e:
                # Best effort: surface the failure to the LLM, don't abort.
                results.append(f"[{tool_name}] 执行失败: {e}")
            else:
                results.append(f"[{tool_name}] {result}")

        tool_result = "\n".join(results)
        state["tool_calls"] = tool_calls
        state["last_tool_result"] = tool_result
        # Repeat the original request so the LLM can interpret the tool
        # output in context.
        follow_up = await llm.ainvoke(
            system_msgs
            + [
                HumanMessage(
                    content=f"用户请求: {user_query}\n\n工具执行结果:\n{tool_result}"
                )
            ]
        )
        state["final_response"] = follow_up.content
    else:
        state["final_response"] = response.content

    state["should_respond"] = True
    return state
async def librarian_node(state: AgentState) -> AgentState:
    """Librarian agent node: manage the knowledge base and knowledge graph.

    Binds ``ALL_TOOLS`` to the LLM and runs every tool call the model
    requests. If any tools ran, the LLM is asked again to turn their output
    into a reply; otherwise the model's first reply is used as-is.

    Args:
        state: Current graph state; ``messages`` supplies the user request.

    Returns:
        The same state object with ``final_response``, ``knowledge_context``
        (this turn's tool output, or "" if no tool ran) and ``should_respond``
        set. When tools were called, ``tool_calls`` and ``last_tool_result``
        are set as well.
    """
    llm = get_llm()
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""

    system_msgs = [SystemMessage(content=LIBRARIAN_SYSTEM_PROMPT)]
    skill_ctx = build_skill_context("librarian")
    if skill_ctx:
        system_msgs.append(SystemMessage(content=skill_ctx))

    # ainvoke is the awaitable API; invoke on a Runnable is synchronous.
    response = await llm.bind_tools(ALL_TOOLS).ainvoke(
        system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
    )

    tool_calls = getattr(response, "tool_calls", None) or []
    # Output of tools run in *this* turn only; never a value left in state.
    tool_result = ""
    if tool_calls:
        results = []
        for tc in tool_calls:
            tool_name = tc.get("name")
            tool = next((t for t in ALL_TOOLS if t.name == tool_name), None)
            if tool is None:
                # Report unknown tools instead of silently dropping the call.
                results.append(f"[{tool_name}] 未找到该工具")
                continue
            try:
                # ainvoke keeps the event loop free while the tool runs.
                result = await tool.ainvoke(tc.get("args", {}))
            except Exception as e:
                # Best effort: surface the failure to the LLM, don't abort.
                results.append(f"[{tool_name}] 执行失败: {e}")
            else:
                results.append(f"[{tool_name}] {result}")

        tool_result = "\n".join(results)
        state["tool_calls"] = tool_calls
        state["last_tool_result"] = tool_result
        # Repeat the original request so the LLM can interpret the tool
        # output in context.
        follow_up = await llm.ainvoke(
            system_msgs
            + [
                HumanMessage(
                    content=f"用户请求: {user_query}\n\n工具执行结果:\n{tool_result}"
                )
            ]
        )
        state["final_response"] = follow_up.content
    else:
        state["final_response"] = response.content

    # Previously read state["last_tool_result"], which could be stale from an
    # earlier turn when no tool ran now.
    state["knowledge_context"] = tool_result
    state["should_respond"] = True
    return state
async def analyst_node(state: AgentState) -> AgentState:
    """Analyst agent node: analyse work data and produce a report.

    Binds ``ALL_TOOLS`` to the LLM and runs every tool call the model
    requests. If any tools ran, the LLM is asked again to turn their output
    into a reply; otherwise the model's first reply is used as-is.

    Args:
        state: Current graph state; ``messages`` supplies the user request.

    Returns:
        The same state object with ``final_response``, ``analysis_report``
        (same text as ``final_response``) and ``should_respond`` set. When
        tools were called, ``tool_calls`` and ``last_tool_result`` are set as
        well.
    """
    llm = get_llm()
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""

    system_msgs = [SystemMessage(content=ANALYST_SYSTEM_PROMPT)]
    skill_ctx = build_skill_context("analyst")
    if skill_ctx:
        system_msgs.append(SystemMessage(content=skill_ctx))

    # ainvoke is the awaitable API; invoke on a Runnable is synchronous.
    response = await llm.bind_tools(ALL_TOOLS).ainvoke(
        system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
    )

    tool_calls = getattr(response, "tool_calls", None) or []
    if tool_calls:
        results = []
        for tc in tool_calls:
            tool_name = tc.get("name")
            tool = next((t for t in ALL_TOOLS if t.name == tool_name), None)
            if tool is None:
                # Report unknown tools instead of silently dropping the call.
                results.append(f"[{tool_name}] 未找到该工具")
                continue
            try:
                # ainvoke keeps the event loop free while the tool runs.
                result = await tool.ainvoke(tc.get("args", {}))
            except Exception as e:
                # Best effort: surface the failure to the LLM, don't abort.
                results.append(f"[{tool_name}] 执行失败: {e}")
            else:
                results.append(f"[{tool_name}] {result}")

        tool_result = "\n".join(results)
        state["tool_calls"] = tool_calls
        state["last_tool_result"] = tool_result
        # Repeat the original request so the LLM can interpret the tool
        # output in context.
        follow_up = await llm.ainvoke(
            system_msgs
            + [
                HumanMessage(
                    content=f"用户请求: {user_query}\n\n工具执行结果:\n{tool_result}"
                )
            ]
        )
        state["final_response"] = follow_up.content
    else:
        state["final_response"] = response.content

    state["analysis_report"] = state["final_response"]
    state["should_respond"] = True
    return state
def route_agent(state: AgentState) -> str:
    """Pick the node to run after the master node.

    Returns ``END`` when ``final_response`` is already set (the master
    answered directly). It also returns ``END`` when no sub-agent was
    selected (``current_agent`` missing or still MASTER), because
    ``"master"`` is not one of the branches registered for the master node's
    conditional edges. Otherwise returns the selected sub-agent's node name.
    """
    if state.get("final_response"):
        return END
    agent = state.get("current_agent")
    # Accept both AgentRole members and plain string role values.
    name = getattr(agent, "value", agent) if agent is not None else None
    if not name or name == AgentRole.MASTER.value:
        return END
    return name
# ===================== 构建图 =====================
def create_agent_graph(callbacks: list | None = None):
    """Build and compile the Jarvis multi-agent graph.

    Topology: the master node is the entry point. ``route_agent`` sends it to
    one sub-agent (planner / executor / librarian / analyst) or to ``END``,
    and every sub-agent goes straight to ``END``.

    Args:
        callbacks: Optional callback handlers attached as default run config
            of the compiled graph.

    Returns:
        The compiled graph, with ``callbacks`` bound via ``with_config`` when
        any are given.
    """
    graph = StateGraph(AgentState)

    nodes = {
        AgentRole.MASTER: master_node,
        AgentRole.PLANNER: planner_node,
        AgentRole.EXECUTOR: executor_node,
        AgentRole.LIBRARIAN: librarian_node,
        AgentRole.ANALYST: analyst_node,
    }
    for role, node in nodes.items():
        graph.add_node(role.value, node)

    graph.set_entry_point(AgentRole.MASTER.value)

    sub_agents = [role for role in nodes if role != AgentRole.MASTER]
    graph.add_conditional_edges(
        AgentRole.MASTER.value,
        route_agent,
        {**{role.value: role.value for role in sub_agents}, END: END},
    )
    for role in sub_agents:
        graph.add_edge(role.value, END)

    # StateGraph.compile() has no ``callbacks`` parameter (passing one raises
    # TypeError), so bind callbacks as the graph's default run config instead.
    compiled = graph.compile()
    if callbacks:
        compiled = compiled.with_config({"callbacks": callbacks})
    return compiled
# Cached compiled graph; None until get_agent_graph() builds it.
_agent_graph = None


def get_agent_graph(callbacks: list | None = None):
    """Return the compiled agent graph, building and caching it on first call.

    Callbacks are fixed when the graph is first compiled. The ``callbacks``
    argument is ignored on every later call, so changing callbacks (for
    example, a different LANGCHAIN_PROJECT) requires restarting the service.

    Args:
        callbacks: Optional extra callbacks, combined with the callbacks from
            ``get_langsmith_callbacks()`` (extra callbacks first).

    Returns:
        The cached compiled graph.
    """
    global _agent_graph
    if _agent_graph is not None:
        return _agent_graph

    # Imported inside the function, so the import only happens when the graph
    # is first built.
    from app.config_tracing import get_langsmith_callbacks

    extra = callbacks or []
    merged = extra + get_langsmith_callbacks()
    _agent_graph = create_agent_graph(callbacks=merged if merged else None)
    return _agent_graph