Add FastAPI backend with agent system
This commit is contained in:
265
backend/app/agents/graph.py
Normal file
265
backend/app/agents/graph.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Jarvis LangGraph Agent 主图定义
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
from app.agents.state import AgentState, AgentRole
|
||||
from app.agents.prompts import (
|
||||
MASTER_SYSTEM_PROMPT,
|
||||
PLANNER_SYSTEM_PROMPT,
|
||||
EXECUTOR_SYSTEM_PROMPT,
|
||||
LIBRARIAN_SYSTEM_PROMPT,
|
||||
ANALYST_SYSTEM_PROMPT,
|
||||
)
|
||||
from app.agents.tools import ALL_TOOLS
|
||||
from app.services.llm_service import get_llm
|
||||
|
||||
|
||||
def _msg_type(msg: BaseMessage) -> str:
|
||||
"""Get message type, handles both .type (new) and .role (old) attribute names."""
|
||||
return getattr(msg, "type", None) or getattr(msg, "role", "human")
|
||||
|
||||
|
||||
def _filter_user_messages(messages: list) -> list[BaseMessage]:
|
||||
return [m for m in messages if _msg_type(m) in ("human", "user")]
|
||||
|
||||
|
||||
# ===================== 节点定义 (async) =====================
|
||||
|
||||
async def master_node(state: AgentState) -> AgentState:
|
||||
"""主Agent节点: 理解用户意图,决定调用哪个子Agent"""
|
||||
llm = get_llm()
|
||||
messages: list[BaseMessage] = state["messages"]
|
||||
|
||||
system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]
|
||||
|
||||
# 注入记忆上下文
|
||||
memory_ctx = state.get("memory_context")
|
||||
if memory_ctx:
|
||||
system_msgs.append(
|
||||
SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
|
||||
)
|
||||
|
||||
response: AIMessage = await llm.invoke(system_msgs + messages)
|
||||
content = response.content.strip().lower()
|
||||
|
||||
if any(kw in content for kw in ["搜索", "查找", "知识", "检索"]):
|
||||
next_agent = AgentRole.LIBRARIAN
|
||||
elif any(kw in content for kw in ["计划", "安排", "拆解", "规划"]):
|
||||
next_agent = AgentRole.PLANNER
|
||||
elif any(kw in content for kw in ["执行", "做", "操作", "创建", "更新"]):
|
||||
next_agent = AgentRole.EXECUTOR
|
||||
elif any(kw in content for kw in ["分析", "报告", "统计", "总结"]):
|
||||
next_agent = AgentRole.ANALYST
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
state["current_agent"] = next_agent
|
||||
state["active_agents"] = state.get("active_agents", [AgentRole.MASTER]) + [next_agent]
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
async def planner_node(state: AgentState) -> AgentState:
|
||||
"""规划Agent节点: 制定计划,拆解任务步骤"""
|
||||
llm = get_llm()
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
|
||||
response = await llm.invoke(
|
||||
[SystemMessage(content=PLANNER_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
|
||||
)
|
||||
|
||||
plan_text = response.content
|
||||
steps = []
|
||||
for i, line in enumerate(plan_text.split("\n")):
|
||||
if line.strip() and (line[0].isdigit() or "- " in line):
|
||||
steps.append({"step": i + 1, "description": line.strip()})
|
||||
|
||||
state["plan"] = plan_text
|
||||
state["plan_steps"] = steps
|
||||
state["final_response"] = plan_text
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
async def executor_node(state: AgentState) -> AgentState:
|
||||
"""执行Agent节点: 调用工具执行具体任务"""
|
||||
llm = get_llm()
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
|
||||
response = await llm.bind_tools(ALL_TOOLS).invoke(
|
||||
[SystemMessage(content=EXECUTOR_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
|
||||
)
|
||||
|
||||
tool_calls = getattr(response, "tool_calls", None) or []
|
||||
|
||||
if tool_calls:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("name")
|
||||
args = tc.get("args", {})
|
||||
for tool in ALL_TOOLS:
|
||||
if tool.name == tool_name:
|
||||
try:
|
||||
result = tool.invoke(args)
|
||||
results.append(f"[{tool_name}] {result}")
|
||||
except Exception as e:
|
||||
results.append(f"[{tool_name}] 执行失败: {e}")
|
||||
break
|
||||
state["tool_calls"] = tool_calls
|
||||
state["last_tool_result"] = "\n".join(results)
|
||||
follow_up = await llm.invoke(
|
||||
[SystemMessage(content=EXECUTOR_SYSTEM_PROMPT),
|
||||
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
|
||||
)
|
||||
state["final_response"] = follow_up.content
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
async def librarian_node(state: AgentState) -> AgentState:
|
||||
"""知识管理员节点: 管理知识库和知识图谱"""
|
||||
llm = get_llm()
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
|
||||
response = await llm.bind_tools(ALL_TOOLS).invoke(
|
||||
[SystemMessage(content=LIBRARIAN_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
|
||||
)
|
||||
|
||||
tool_calls = getattr(response, "tool_calls", None) or []
|
||||
|
||||
if tool_calls:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("name")
|
||||
args = tc.get("args", {})
|
||||
for tool in ALL_TOOLS:
|
||||
if tool.name == tool_name:
|
||||
try:
|
||||
result = tool.invoke(args)
|
||||
results.append(f"[{tool_name}] {result}")
|
||||
except Exception as e:
|
||||
results.append(f"[{tool_name}] 执行失败: {e}")
|
||||
break
|
||||
state["tool_calls"] = tool_calls
|
||||
state["last_tool_result"] = "\n".join(results)
|
||||
follow_up = await llm.invoke(
|
||||
[SystemMessage(content=LIBRARIAN_SYSTEM_PROMPT),
|
||||
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
|
||||
)
|
||||
state["final_response"] = follow_up.content
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
|
||||
state["knowledge_context"] = state.get("last_tool_result", "")
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
async def analyst_node(state: AgentState) -> AgentState:
|
||||
"""分析师节点: 分析工作数据,生成报告"""
|
||||
llm = get_llm()
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
|
||||
response = await llm.bind_tools(ALL_TOOLS).invoke(
|
||||
[SystemMessage(content=ANALYST_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
|
||||
)
|
||||
|
||||
tool_calls = getattr(response, "tool_calls", None) or []
|
||||
|
||||
if tool_calls:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("name")
|
||||
args = tc.get("args", {})
|
||||
for tool in ALL_TOOLS:
|
||||
if tool.name == tool_name:
|
||||
try:
|
||||
result = tool.invoke(args)
|
||||
results.append(f"[{tool_name}] {result}")
|
||||
except Exception as e:
|
||||
results.append(f"[{tool_name}] 执行失败: {e}")
|
||||
break
|
||||
state["tool_calls"] = tool_calls
|
||||
state["last_tool_result"] = "\n".join(results)
|
||||
follow_up = await llm.invoke(
|
||||
[SystemMessage(content=ANALYST_SYSTEM_PROMPT),
|
||||
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
|
||||
)
|
||||
state["final_response"] = follow_up.content
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
|
||||
state["analysis_report"] = state.get("final_response", "")
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
def route_agent(state: AgentState) -> str:
|
||||
"""路由函数: 决定下一个节点"""
|
||||
if state.get("final_response"):
|
||||
return END
|
||||
return state.get("current_agent", AgentRole.MASTER).value
|
||||
|
||||
|
||||
# ===================== 构建图 =====================
|
||||
|
||||
def create_agent_graph(callbacks: list | None = None):
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
graph.add_node(AgentRole.MASTER.value, master_node)
|
||||
graph.add_node(AgentRole.PLANNER.value, planner_node)
|
||||
graph.add_node(AgentRole.EXECUTOR.value, executor_node)
|
||||
graph.add_node(AgentRole.LIBRARIAN.value, librarian_node)
|
||||
graph.add_node(AgentRole.ANALYST.value, analyst_node)
|
||||
|
||||
graph.set_entry_point(AgentRole.MASTER.value)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
AgentRole.MASTER.value,
|
||||
route_agent,
|
||||
{
|
||||
AgentRole.PLANNER.value: AgentRole.PLANNER.value,
|
||||
AgentRole.EXECUTOR.value: AgentRole.EXECUTOR.value,
|
||||
AgentRole.LIBRARIAN.value: AgentRole.LIBRARIAN.value,
|
||||
AgentRole.ANALYST.value: AgentRole.ANALYST.value,
|
||||
END: END,
|
||||
}
|
||||
)
|
||||
|
||||
for role in [AgentRole.PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
|
||||
graph.add_edge(role.value, END)
|
||||
|
||||
return graph.compile(callbacks=callbacks)
|
||||
|
||||
|
||||
_agent_graph = None
|
||||
|
||||
|
||||
def get_agent_graph(callbacks: list | None = None):
|
||||
"""
|
||||
获取编译好的 Agent 图(单例缓存)。
|
||||
|
||||
Callbacks 在首次编译时固定注入,后续调用忽略 callbacks 参数。
|
||||
如需变更 Callbacks(如修改 LANGCHAIN_PROJECT),需重启服务。
|
||||
|
||||
Args:
|
||||
callbacks: 可选的额外 Callbacks,会与全局 LangSmith Callbacks 合并
|
||||
"""
|
||||
global _agent_graph
|
||||
if _agent_graph is None:
|
||||
from app.config_tracing import get_langsmith_callbacks
|
||||
langsmith_callbacks = get_langsmith_callbacks()
|
||||
all_callbacks = (callbacks or []) + langsmith_callbacks
|
||||
_agent_graph = create_agent_graph(callbacks=all_callbacks or None)
|
||||
return _agent_graph
|
||||
Reference in New Issue
Block a user