Update agent graph orchestration prompts

Refresh the agent graph state and prompt wiring so the newer backend and
frontend orchestration features share the same execution model. This
keeps the remaining agent-side changes aligned with the rest of the
batch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-22 13:50:01 +08:00
parent 90ea732584
commit 67ea3d2682
3 changed files with 342 additions and 54 deletions

View File

@@ -15,6 +15,77 @@ from app.agents.prompts import (
from app.agents.tools import ALL_TOOLS
from app.agents.skill_registry import build_skill_context
from app.services.llm_service import get_llm
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
import httpx
def _create_llm_from_config(config: dict):
"""根据用户模型配置创建 LLM 实例"""
provider = config.get("provider", "openai")
model = config.get("model", "")
api_key = config.get("api_key", "")
base_url = config.get("base_url", "")
if provider == "openai" or provider == "deepseek" or provider == "custom":
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "claude":
return ChatAnthropic(
api_key=api_key,
model=model,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "ollama":
return ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=httpx.Timeout(120.0, connect=10.0),
)
else:
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
def _get_llm_for_state(state: AgentState):
"""从 state 获取 LLM 实例,优先使用用户配置的模型"""
user_llm_config = state.get("user_llm_config")
if user_llm_config:
return _create_llm_from_config(user_llm_config)
return get_llm()
async def _ainvoke(llm, messages: list[BaseMessage]):
ainvoke = getattr(llm, "ainvoke", None)
if callable(ainvoke):
return await ainvoke(messages)
return await llm.invoke(messages)
async def _ainvoke_with_tools(llm, messages: list[BaseMessage]):
bound_llm = llm.bind_tools(ALL_TOOLS)
if hasattr(bound_llm, "ainvoke"):
return await bound_llm.ainvoke(messages)
return await bound_llm.invoke(messages)
def _compile_graph(graph: StateGraph, callbacks: list | None = None):
if callbacks:
try:
return graph.compile(callbacks=callbacks)
except TypeError as exc:
if "callbacks" not in str(exc):
raise
return graph.compile()
def _msg_type(msg: BaseMessage) -> str:
@@ -30,7 +101,7 @@ def _filter_user_messages(messages: list) -> list[BaseMessage]:
async def master_node(state: AgentState) -> AgentState:
"""主Agent节点: 理解用户意图决定调用哪个子Agent"""
llm = get_llm()
llm = _get_llm_for_state(state)
messages: list[BaseMessage] = state["messages"]
system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]
@@ -42,7 +113,7 @@ async def master_node(state: AgentState) -> AgentState:
SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
)
response: AIMessage = await llm.invoke(system_msgs + messages)
response: AIMessage = await _ainvoke(llm,system_msgs + messages)
content = response.content.strip().lower()
if any(kw in content for kw in ["搜索", "查找", "知识", "检索"]):
@@ -66,7 +137,7 @@ async def master_node(state: AgentState) -> AgentState:
async def planner_node(state: AgentState) -> AgentState:
"""规划Agent节点: 制定计划,拆解任务步骤"""
llm = get_llm()
llm = _get_llm_for_state(state)
user_msgs = _filter_user_messages(state["messages"])
user_query = user_msgs[-1].content if user_msgs else ""
@@ -75,7 +146,7 @@ async def planner_node(state: AgentState) -> AgentState:
if skill_ctx:
system_msgs.append(SystemMessage(content=skill_ctx))
response = await llm.invoke(
response = await _ainvoke(llm,
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
)
@@ -94,7 +165,7 @@ async def planner_node(state: AgentState) -> AgentState:
async def executor_node(state: AgentState) -> AgentState:
"""执行Agent节点: 调用工具执行具体任务"""
llm = get_llm()
llm = _get_llm_for_state(state)
user_msgs = _filter_user_messages(state["messages"])
user_query = user_msgs[-1].content if user_msgs else ""
@@ -103,7 +174,7 @@ async def executor_node(state: AgentState) -> AgentState:
if skill_ctx:
system_msgs.append(SystemMessage(content=skill_ctx))
response = await llm.bind_tools(ALL_TOOLS).invoke(
response = await _ainvoke_with_tools(llm,
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
)
@@ -124,7 +195,7 @@ async def executor_node(state: AgentState) -> AgentState:
break
state["tool_calls"] = tool_calls
state["last_tool_result"] = "\n".join(results)
follow_up = await llm.invoke(
follow_up = await _ainvoke(llm,
[SystemMessage(content=EXECUTOR_SYSTEM_PROMPT),
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
)
@@ -138,7 +209,7 @@ async def executor_node(state: AgentState) -> AgentState:
async def librarian_node(state: AgentState) -> AgentState:
"""知识管理员节点: 管理知识库和知识图谱"""
llm = get_llm()
llm = _get_llm_for_state(state)
user_msgs = _filter_user_messages(state["messages"])
user_query = user_msgs[-1].content if user_msgs else ""
@@ -147,7 +218,7 @@ async def librarian_node(state: AgentState) -> AgentState:
if skill_ctx:
system_msgs.append(SystemMessage(content=skill_ctx))
response = await llm.bind_tools(ALL_TOOLS).invoke(
response = await _ainvoke_with_tools(llm,
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
)
@@ -168,7 +239,7 @@ async def librarian_node(state: AgentState) -> AgentState:
break
state["tool_calls"] = tool_calls
state["last_tool_result"] = "\n".join(results)
follow_up = await llm.invoke(
follow_up = await _ainvoke(llm,
[SystemMessage(content=LIBRARIAN_SYSTEM_PROMPT),
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
)
@@ -183,7 +254,7 @@ async def librarian_node(state: AgentState) -> AgentState:
async def analyst_node(state: AgentState) -> AgentState:
"""分析师节点: 分析工作数据,生成报告"""
llm = get_llm()
llm = _get_llm_for_state(state)
user_msgs = _filter_user_messages(state["messages"])
user_query = user_msgs[-1].content if user_msgs else ""
@@ -192,7 +263,7 @@ async def analyst_node(state: AgentState) -> AgentState:
if skill_ctx:
system_msgs.append(SystemMessage(content=skill_ctx))
response = await llm.bind_tools(ALL_TOOLS).invoke(
response = await _ainvoke_with_tools(llm,
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")]
)
@@ -213,7 +284,7 @@ async def analyst_node(state: AgentState) -> AgentState:
break
state["tool_calls"] = tool_calls
state["last_tool_result"] = "\n".join(results)
follow_up = await llm.invoke(
follow_up = await _ainvoke(llm,
[SystemMessage(content=ANALYST_SYSTEM_PROMPT),
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
)
@@ -261,7 +332,7 @@ def create_agent_graph(callbacks: list | None = None):
for role in [AgentRole.PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
graph.add_edge(role.value, END)
return graph.compile(callbacks=callbacks)
return _compile_graph(graph, callbacks=callbacks)
_agent_graph = None