refactor(server): steward 决策链路改用 LangGraph 编排
- 新增 StewardGraphPlannerService,用 LangGraph 状态图编排意图识别→流程判断→模型/规则分支→兜底,替代原 planner 内线性调用 - 新增 StewardGraphRuntimeService 编排运行时决策与槽位决策;StewardActionContracts/Executor 统一动作合约与执行 - steward_intent_agent/application_fact_resolver/runtime_chat 适配图执行器,config 暴露图相关开关 - pyproject/uv.lock 新增 langgraph 依赖 - 新增 graph_planner/graph_runtime/action_executor 测试,更新 intent_agent/planner/fact_resolver/runtime_chat/reimbursement 测试
This commit is contained in:
204
server/src/app/services/steward_graph_action_runtime.py
Normal file
204
server/src/app/services/steward_graph_action_runtime.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import CurrentUserContext
|
||||
from app.models.agent_conversation import AgentConversation
|
||||
from app.schemas.steward import StewardActionExecuteRequest, StewardActionExecuteResponse
|
||||
from app.services.agent_conversations import AgentConversationService
|
||||
from app.services.steward_action_executor import StewardActionExecutor
|
||||
|
||||
ACTION_CHECKPOINT_KEY = "steward_action_checkpoint"
|
||||
TERMINAL_ACTION_STATUSES = {"succeeded", "blocked", "failed"}
|
||||
|
||||
|
||||
class StewardGraphActionState(TypedDict, total=False):
|
||||
request: StewardActionExecuteRequest
|
||||
current_user: CurrentUserContext
|
||||
conversation: AgentConversation | None
|
||||
trace_id: str
|
||||
existing_result: dict[str, Any]
|
||||
response: StewardActionExecuteResponse
|
||||
|
||||
|
||||
class StewardGraphActionRuntime:
|
||||
"""用 LangGraph 包装小财管家白名单动作执行、checkpoint 和幂等重放。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.executor = StewardActionExecutor(db)
|
||||
self._graph = self._build_graph()
|
||||
|
||||
def execute(
|
||||
self,
|
||||
request: StewardActionExecuteRequest,
|
||||
current_user: CurrentUserContext,
|
||||
) -> StewardActionExecuteResponse:
|
||||
final_state = self._graph.invoke(
|
||||
{
|
||||
"request": request,
|
||||
"current_user": current_user,
|
||||
}
|
||||
)
|
||||
response = final_state.get("response")
|
||||
if not isinstance(response, StewardActionExecuteResponse):
|
||||
raise RuntimeError("LangGraph action runtime 未生成有效执行结果。")
|
||||
return response
|
||||
|
||||
def _build_graph(self):
|
||||
graph = StateGraph(StewardGraphActionState)
|
||||
graph.add_node("action_checkpoint_load", self._load_checkpoint)
|
||||
graph.add_node("action_execute_node", self._execute_action)
|
||||
graph.add_node("action_checkpoint_persist", self._persist_checkpoint)
|
||||
|
||||
graph.add_edge(START, "action_checkpoint_load")
|
||||
graph.add_conditional_edges(
|
||||
"action_checkpoint_load",
|
||||
self._route_after_checkpoint_load,
|
||||
{
|
||||
"replay": "action_checkpoint_persist",
|
||||
"execute": "action_execute_node",
|
||||
},
|
||||
)
|
||||
graph.add_edge("action_execute_node", "action_checkpoint_persist")
|
||||
graph.add_edge("action_checkpoint_persist", END)
|
||||
return graph.compile()
|
||||
|
||||
def _load_checkpoint(self, state: StewardGraphActionState) -> dict[str, Any]:
|
||||
request = state["request"]
|
||||
conversation = self._get_or_create_conversation(request, state["current_user"])
|
||||
trace_id = self._resolve_trace_id(request)
|
||||
if conversation is None or not trace_id:
|
||||
return {
|
||||
"conversation": conversation,
|
||||
"trace_id": trace_id,
|
||||
}
|
||||
|
||||
checkpoint = self._resolve_checkpoint(conversation)
|
||||
existing = dict(checkpoint.get("actions", {}).get(trace_id) or {})
|
||||
existing_response = existing.get("response")
|
||||
status = str(existing.get("status") or "").strip()
|
||||
if isinstance(existing_response, dict) and (
|
||||
status in TERMINAL_ACTION_STATUSES
|
||||
or (status == "needs_confirmation" and not request.confirmed)
|
||||
):
|
||||
response = StewardActionExecuteResponse.model_validate(existing_response)
|
||||
response.result_payload = {
|
||||
**dict(response.result_payload or {}),
|
||||
"idempotent_replay": True,
|
||||
}
|
||||
response.trace = [
|
||||
*list(response.trace or []),
|
||||
self._trace("checkpoint_replay", client_trace_id=trace_id),
|
||||
]
|
||||
return {
|
||||
"conversation": conversation,
|
||||
"trace_id": trace_id,
|
||||
"existing_result": existing,
|
||||
"response": response,
|
||||
}
|
||||
|
||||
return {
|
||||
"conversation": conversation,
|
||||
"trace_id": trace_id,
|
||||
"existing_result": existing,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _route_after_checkpoint_load(state: StewardGraphActionState) -> str:
|
||||
if isinstance(state.get("response"), StewardActionExecuteResponse):
|
||||
return "replay"
|
||||
return "execute"
|
||||
|
||||
def _execute_action(self, state: StewardGraphActionState) -> dict[str, StewardActionExecuteResponse]:
|
||||
response = self.executor.execute(state["request"], state["current_user"])
|
||||
return {"response": response}
|
||||
|
||||
def _persist_checkpoint(self, state: StewardGraphActionState) -> dict[str, StewardActionExecuteResponse]:
|
||||
response = state["response"]
|
||||
conversation = state.get("conversation")
|
||||
trace_id = str(state.get("trace_id") or "").strip()
|
||||
if conversation is None or not trace_id:
|
||||
return {"response": response}
|
||||
|
||||
checkpoint = self._resolve_checkpoint(conversation)
|
||||
actions = dict(checkpoint.get("actions") or {})
|
||||
request = state["request"]
|
||||
response_payload = response.model_dump(mode="json")
|
||||
actions[trace_id] = {
|
||||
"client_trace_id": trace_id,
|
||||
"action_type": response.action_type,
|
||||
"status": response.status,
|
||||
"request": request.model_dump(mode="json"),
|
||||
"response": response_payload,
|
||||
"updated_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
checkpoint["actions"] = actions
|
||||
if response.status == "needs_confirmation":
|
||||
checkpoint["pending_interrupt"] = {
|
||||
"client_trace_id": trace_id,
|
||||
"action_type": response.action_type,
|
||||
"message": response.message,
|
||||
"requires_confirmation": True,
|
||||
"updated_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
elif str(checkpoint.get("pending_interrupt", {}).get("client_trace_id") or "") == trace_id:
|
||||
checkpoint["pending_interrupt"] = {}
|
||||
|
||||
conversation.state_json = {
|
||||
**dict(conversation.state_json or {}),
|
||||
ACTION_CHECKPOINT_KEY: checkpoint,
|
||||
}
|
||||
conversation.updated_at = datetime.now(UTC)
|
||||
self.db.add(conversation)
|
||||
self.db.commit()
|
||||
return {"response": response}
|
||||
|
||||
def _get_or_create_conversation(
|
||||
self,
|
||||
request: StewardActionExecuteRequest,
|
||||
current_user: CurrentUserContext,
|
||||
) -> AgentConversation | None:
|
||||
conversation_id = str(request.conversation_id or "").strip()
|
||||
if not conversation_id:
|
||||
return None
|
||||
return AgentConversationService(self.db).get_or_create_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_id=current_user.username,
|
||||
source="user_message",
|
||||
context_json={
|
||||
"session_type": "steward",
|
||||
"entry_source": "steward_action_executor",
|
||||
"steward_state": dict((request.context_json or {}).get("steward_state") or {}),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_checkpoint(conversation: AgentConversation) -> dict[str, Any]:
|
||||
checkpoint = dict((conversation.state_json or {}).get(ACTION_CHECKPOINT_KEY) or {})
|
||||
checkpoint.setdefault("actions", {})
|
||||
checkpoint.setdefault("pending_interrupt", {})
|
||||
return checkpoint
|
||||
|
||||
@staticmethod
|
||||
def _resolve_trace_id(request: StewardActionExecuteRequest) -> str:
|
||||
trace_id = str(request.client_trace_id or "").strip()
|
||||
if trace_id:
|
||||
return trace_id
|
||||
action_type = str(request.action_type or "").strip()
|
||||
task_id = str(request.task.task_id if request.task is not None else "").strip()
|
||||
if action_type and task_id:
|
||||
return f"{action_type}:{task_id}"
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _trace(stage: str, **extra: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"stage": stage,
|
||||
"at": datetime.now(UTC).isoformat(),
|
||||
**extra,
|
||||
}
|
||||
Reference in New Issue
Block a user