Files
JARVIS/backend/app/agents/graph.py

394 lines
14 KiB
Python
Raw Normal View History

2026-03-21 10:13:29 +08:00
"""
Jarvis LangGraph Agent 主图定义
"""
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
from app.agents.state import AgentState, AgentRole
from app.agents.prompts import (
MASTER_SYSTEM_PROMPT,
PLANNER_SYSTEM_PROMPT,
EXECUTOR_SYSTEM_PROMPT,
LIBRARIAN_SYSTEM_PROMPT,
ANALYST_SYSTEM_PROMPT,
)
from app.agents.tools import ALL_TOOLS
from app.agents.skill_registry import build_skill_context
2026-03-21 10:13:29 +08:00
from app.services.llm_service import get_llm
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
import httpx
def _create_llm_from_config(config: dict):
    """Build a chat model instance from a user's model configuration.

    Args:
        config: mapping with optional keys ``provider`` (default ``"openai"``),
            ``model``, ``api_key`` and ``base_url``.

    Returns:
        ``ChatAnthropic`` for ``"claude"``, ``ChatOllama`` for ``"ollama"``,
        otherwise ``ChatOpenAI``. ``"openai"``, ``"deepseek"``, ``"custom"``
        and any unrecognised provider all go through the OpenAI-compatible
        client.
    """
    provider = config.get("provider", "openai")
    model = config.get("model", "")
    api_key = config.get("api_key", "")
    base_url = config.get("base_url", "")

    if provider == "claude":
        claude_kwargs = {}
        if base_url:
            # Honour a user-supplied endpoint (e.g. a proxy). Only pass it when
            # set so the library's own default endpoint is kept otherwise.
            claude_kwargs["base_url"] = base_url
        return ChatAnthropic(
            api_key=api_key,
            model=model,
            timeout=httpx.Timeout(60.0, connect=10.0),
            **claude_kwargs,
        )

    if provider == "ollama":
        return ChatOllama(
            base_url=base_url or "http://localhost:11434",
            model=model,
            timeout=httpx.Timeout(120.0, connect=10.0),
        )

    # OpenAI-compatible API: explicit openai/deepseek/custom plus the fallback
    # for unknown providers.
    return ChatOpenAI(
        api_key=api_key,
        model=model,
        base_url=base_url or None,
        timeout=httpx.Timeout(60.0, connect=10.0),
    )
def _get_llm_for_state(state: AgentState):
    """Return the chat model to use for this graph run.

    A truthy ``state["user_llm_config"]`` wins and is turned into a new model
    via ``_create_llm_from_config``; otherwise the default from ``get_llm()``
    is returned.
    """
    override = state.get("user_llm_config")
    return _create_llm_from_config(override) if override else get_llm()
async def _ainvoke(llm, messages: list[BaseMessage]):
    """Invoke *llm* on *messages* asynchronously and return its response.

    Uses the model's native ``ainvoke`` when it has a callable one. Otherwise
    the blocking ``invoke`` is run in a worker thread: ``invoke`` returns a
    plain result rather than an awaitable, so awaiting it directly would raise
    ``TypeError``.
    """
    ainvoke = getattr(llm, "ainvoke", None)
    if callable(ainvoke):
        return await ainvoke(messages)

    import asyncio  # local import: only needed on the sync-only fallback path

    return await asyncio.to_thread(llm.invoke, messages)
async def _ainvoke_with_tools(llm, messages: list[BaseMessage]):
    """Bind every tool in ``ALL_TOOLS`` to *llm* and invoke it asynchronously.

    Invocation is delegated to ``_ainvoke`` so the async/sync fallback lives in
    one place. (The previous inline fallback awaited the synchronous
    ``invoke`` result, which raises ``TypeError``.)
    """
    return await _ainvoke(llm.bind_tools(ALL_TOOLS), messages)
def _compile_graph(graph: StateGraph, callbacks: list | None = None):
    """Compile *graph*, making sure *callbacks* are attached when provided.

    ``compile(callbacks=...)`` is tried first. If that raises a ``TypeError``
    about ``callbacks`` (the keyword is not part of ``StateGraph.compile``),
    the graph is compiled plainly and the callbacks are bound into its default
    run config with ``with_config``, instead of being silently dropped. Any
    other ``TypeError`` is re-raised.
    """
    if not callbacks:
        return graph.compile()

    try:
        return graph.compile(callbacks=callbacks)
    except TypeError as exc:
        if "callbacks" not in str(exc):
            raise

    compiled = graph.compile()
    with_config = getattr(compiled, "with_config", None)
    if callable(with_config):
        # Compiled graphs are Runnables; config bound here is merged into
        # every invocation's config.
        return with_config({"callbacks": callbacks})
    return compiled
2026-03-21 10:13:29 +08:00
def _msg_type(msg: BaseMessage) -> str:
    """Return the message's kind string.

    Prefers a truthy ``.type`` attribute; otherwise falls back to ``.role``,
    and to ``"human"`` when neither is available.
    """
    kind = getattr(msg, "type", None)
    if kind:
        return kind
    return getattr(msg, "role", "human")
def _filter_user_messages(messages: list) -> list[BaseMessage]:
    """Return the user-authored messages from *messages*, preserving order.

    A message is user-authored when ``_msg_type`` reports ``"human"`` or
    ``"user"``.
    """
    user_kinds = ("human", "user")
    return [message for message in messages if _msg_type(message) in user_kinds]
def _normalize_user_text(text: str) -> str:
return (text or "").strip().lower()
def _is_simple_greeting(text: str) -> bool:
normalized = _normalize_user_text(text)
return normalized in {"你好", "您好", "", "早上好", "在吗", "", "hi", "hello"}
def _is_identity_question(text: str) -> bool:
normalized = _normalize_user_text(text)
return normalized in {"你是谁", "你是誰"}
def _is_capability_question(text: str) -> bool:
normalized = _normalize_user_text(text)
return normalized in {"你能做什么", "你可以做什么", "你会做什么"}
2026-03-21 10:13:29 +08:00
# ===================== 节点定义 (async) =====================
def _quick_reply(user_query: str) -> str | None:
    """Return a canned reply for greeting / identity / capability questions, else None."""
    if _is_simple_greeting(user_query):
        return "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。"
    if _is_identity_question(user_query):
        return "我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。"
    if _is_capability_question(user_query):
        return "主要做三件事。\n- 帮您判断:看问题本质、梳理取舍、给出方向\n- 帮您收束:把复杂内容理顺,把重点拎出来\n- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n如果您现在有具体目标,我可以直接进入处理。"
    return None


# Keyword routing table, checked in order; the first role whose keyword
# appears in the master's reply wins. The executor list previously contained
# an empty string (likely a character lost in encoding), which matched every
# reply and made the analyst and direct-answer branches unreachable.
_ROUTING_KEYWORDS = (
    (AgentRole.LIBRARIAN, ("搜索", "查找", "知识", "检索")),
    (AgentRole.PLANNER, ("计划", "安排", "拆解", "规划")),
    (AgentRole.EXECUTOR, ("执行", "操作", "创建", "更新")),
    (AgentRole.ANALYST, ("分析", "报告", "统计", "总结")),
)


def _pick_sub_agent(content: str) -> AgentRole | None:
    """Return the first sub-agent whose keywords occur in *content*, or None."""
    for role, keywords in _ROUTING_KEYWORDS:
        if any(kw in content for kw in keywords):
            return role
    return None


async def master_node(state: AgentState) -> AgentState:
    """Master agent: understand the user's intent and decide who handles it.

    Trivial questions get a canned reply without calling the LLM. Otherwise the
    LLM is asked (with the optional memory context) and its reply is scanned
    for routing keywords: a match hands off to that sub-agent via
    ``current_agent``; no match means the reply itself is the final answer.
    """
    messages: list[BaseMessage] = state["messages"]
    user_msgs = _filter_user_messages(messages)
    user_query = user_msgs[-1].content.strip() if user_msgs else ""

    canned = _quick_reply(user_query)
    if canned is not None:
        state["final_response"] = canned
        state["should_respond"] = True
        return state

    llm = _get_llm_for_state(state)
    system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]
    # Inject the memory context as an extra system message when present.
    memory_ctx = state.get("memory_context")
    if memory_ctx:
        system_msgs.append(
            SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
        )
    response: AIMessage = await _ainvoke(llm, system_msgs + messages)

    next_agent = _pick_sub_agent(response.content.strip().lower())
    if next_agent is None:
        state["final_response"] = response.content
        state["should_respond"] = True
        return state

    state["current_agent"] = next_agent
    # Tolerate an explicit None in state (``+`` on None would raise).
    active_agents = state.get("active_agents")
    if active_agents is None:
        active_agents = [AgentRole.MASTER]
    state["active_agents"] = list(active_agents) + [next_agent]
    state["should_respond"] = True
    return state
def _parse_plan_steps(plan_text: str) -> list[dict]:
    """Extract numbered or ``- ``-bulleted lines from *plan_text* as plan steps.

    Leading whitespace is ignored when detecting list items, and steps are
    numbered consecutively from 1 in the order they appear. Blank and prose
    lines no longer create gaps in the numbering.
    """
    steps: list[dict] = []
    for line in plan_text.splitlines():
        stripped = line.strip()
        if stripped and (stripped[0].isdigit() or stripped.startswith("- ")):
            steps.append({"step": len(steps) + 1, "description": stripped})
    return steps


async def planner_node(state: AgentState) -> AgentState:
    """Planner agent: draft a plan for the latest user request.

    Stores the raw plan text in ``plan`` and ``final_response``, and the parsed
    list items in ``plan_steps``.
    """
    llm = _get_llm_for_state(state)
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""

    system_msgs = [SystemMessage(content=PLANNER_SYSTEM_PROMPT)]
    skill_ctx = build_skill_context("planner")
    if skill_ctx:
        system_msgs.append(SystemMessage(content=skill_ctx))

    response = await _ainvoke(
        llm, system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
    )
    plan_text = response.content

    state["plan"] = plan_text
    state["plan_steps"] = _parse_plan_steps(plan_text)
    state["final_response"] = plan_text
    state["should_respond"] = True
    return state
async def _execute_tool_calls(tool_calls: list) -> str:
    """Run every tool call the model requested and collect their outputs.

    Each call contributes one ``[tool_name] ...`` line, joined with newlines.
    Tool errors and unknown tool names are reported inline rather than raised,
    so the follow-up model call can explain them to the user.
    """
    # If two tools share a name, the first one in ALL_TOOLS wins.
    tools_by_name: dict = {}
    for tool in ALL_TOOLS:
        tools_by_name.setdefault(tool.name, tool)

    results: list[str] = []
    for tc in tool_calls:
        tool_name = tc.get("name")
        tool = tools_by_name.get(tool_name)
        if tool is None:
            # Report it so the model knows the requested call did not run.
            results.append(f"[{tool_name}] 执行失败: 未找到该工具")
            continue
        try:
            # ainvoke keeps the event loop free and also works for async-only tools.
            result = await tool.ainvoke(tc.get("args", {}))
        except Exception as e:  # best-effort: surface the failure to the model
            results.append(f"[{tool_name}] 执行失败: {e}")
        else:
            results.append(f"[{tool_name}] {result}")
    return "\n".join(results)


async def _run_tool_agent(
    state: AgentState, system_prompt: str, skill_role: str
) -> str | None:
    """Run one tool-enabled sub-agent turn and record the outcome in *state*.

    The model sees *system_prompt*, the skill context for *skill_role* and the
    latest user request. If it answers directly, that answer becomes
    ``final_response``. If it requests tools, they are executed, ``tool_calls``
    and ``last_tool_result`` are set, and a follow-up call summarises the
    results into ``final_response``.

    Returns:
        This turn's combined tool output, or None when no tool was called.
    """
    llm = _get_llm_for_state(state)
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""

    system_msgs = [SystemMessage(content=system_prompt)]
    skill_ctx = build_skill_context(skill_role)
    if skill_ctx:
        system_msgs.append(SystemMessage(content=skill_ctx))

    response = await _ainvoke_with_tools(
        llm, system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
    )
    tool_calls = getattr(response, "tool_calls", None) or []
    if not tool_calls:
        state["final_response"] = response.content
        return None

    tool_output = await _execute_tool_calls(tool_calls)
    state["tool_calls"] = tool_calls
    state["last_tool_result"] = tool_output
    # Repeat the user's request so the summary answers the actual question
    # instead of only restating raw tool output.
    follow_up = await _ainvoke(
        llm,
        [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"用户请求: {user_query}\n\n工具执行结果:\n{tool_output}"),
        ],
    )
    state["final_response"] = follow_up.content
    return tool_output


async def executor_node(state: AgentState) -> AgentState:
    """Executor agent: call tools to carry out the concrete task."""
    await _run_tool_agent(state, EXECUTOR_SYSTEM_PROMPT, "executor")
    state["should_respond"] = True
    return state


async def librarian_node(state: AgentState) -> AgentState:
    """Librarian agent: work with the knowledge base and knowledge graph.

    ``knowledge_context`` receives this turn's tool output ("" when no tool
    ran), so a stale ``last_tool_result`` from earlier state cannot leak in.
    """
    tool_output = await _run_tool_agent(state, LIBRARIAN_SYSTEM_PROMPT, "librarian")
    state["knowledge_context"] = tool_output or ""
    state["should_respond"] = True
    return state


async def analyst_node(state: AgentState) -> AgentState:
    """Analyst agent: analyse work data and produce a report.

    The final answer is also stored as ``analysis_report``.
    """
    await _run_tool_agent(state, ANALYST_SYSTEM_PROMPT, "analyst")
    state["analysis_report"] = state.get("final_response", "")
    state["should_respond"] = True
    return state
def route_agent(state: AgentState) -> str:
    """Choose the node that runs after ``master``.

    Returns ``END`` when the master already produced a ``final_response``. It
    also returns ``END`` when there is no sub-agent to hand off to:
    ``current_agent`` is missing, ``None`` or ``MASTER``, and ``"master"`` has
    no entry in the path map built by ``create_agent_graph``, so returning it
    would make the graph fail. Otherwise returns the sub-agent's node name;
    plain-string roles are accepted as well as ``AgentRole`` members.
    """
    if state.get("final_response"):
        return END
    agent = state.get("current_agent")
    name = getattr(agent, "value", agent)
    if not name or name == AgentRole.MASTER.value:
        return END
    return name
# ===================== 构建图 =====================
def create_agent_graph(callbacks: list | None = None):
    """Build and compile the Jarvis agent graph.

    ``master`` is the entry node. Its conditional edge, driven by
    ``route_agent``, either ends the run or hands off to exactly one
    sub-agent, and every sub-agent goes straight to ``END``. *callbacks* are
    forwarded to ``_compile_graph``.
    """
    sub_agents = {
        AgentRole.PLANNER: planner_node,
        AgentRole.EXECUTOR: executor_node,
        AgentRole.LIBRARIAN: librarian_node,
        AgentRole.ANALYST: analyst_node,
    }

    graph = StateGraph(AgentState)
    graph.add_node(AgentRole.MASTER.value, master_node)
    for role, node in sub_agents.items():
        graph.add_node(role.value, node)
    graph.set_entry_point(AgentRole.MASTER.value)

    path_map = {role.value: role.value for role in sub_agents}
    path_map[END] = END
    graph.add_conditional_edges(AgentRole.MASTER.value, route_agent, path_map)

    for role in sub_agents:
        graph.add_edge(role.value, END)
    return _compile_graph(graph, callbacks=callbacks)
2026-03-21 10:13:29 +08:00
_agent_graph = None
def get_agent_graph(callbacks: list | None = None):
"""
获取编译好的 Agent 单例缓存
Callbacks 在首次编译时固定注入后续调用忽略 callbacks 参数
如需变更 Callbacks如修改 LANGCHAIN_PROJECT需重启服务
Args:
callbacks: 可选的额外 Callbacks会与全局 LangSmith Callbacks 合并
"""
global _agent_graph
if _agent_graph is None:
from app.config_tracing import get_langsmith_callbacks
langsmith_callbacks = get_langsmith_callbacks()
all_callbacks = (callbacks or []) + langsmith_callbacks
_agent_graph = create_agent_graph(callbacks=all_callbacks or None)
return _agent_graph