Files
X-Financial/server/src/app/services/steward_graph_action_runtime.py
caoxiaozhu 5311c99d69 refactor(server): steward 决策链路改用 LangGraph 编排
- 新增 StewardGraphPlannerService,用 LangGraph 状态图编排意图识别→流程判断→模型/规则分支→兜底,替代原 planner 内线性调用
- 新增 StewardGraphRuntimeService 编排运行时决策与槽位决策;StewardActionContracts/Executor 统一动作合约与执行
- steward_intent_agent/application_fact_resolver/runtime_chat 适配图执行器,config 暴露图相关开关
- pyproject/uv.lock 新增 langgraph 依赖
- 新增 graph_planner/graph_runtime/action_executor 测试,更新 intent_agent/planner/fact_resolver/runtime_chat/reimbursement 测试
2026-06-24 21:58:35 +08:00

205 lines
7.9 KiB
Python

from __future__ import annotations
from datetime import UTC, datetime
from typing import Any, TypedDict
from langgraph.graph import END, START, StateGraph
from sqlalchemy.orm import Session
from app.api.deps import CurrentUserContext
from app.models.agent_conversation import AgentConversation
from app.schemas.steward import StewardActionExecuteRequest, StewardActionExecuteResponse
from app.services.agent_conversations import AgentConversationService
from app.services.steward_action_executor import StewardActionExecutor
# Key inside ``AgentConversation.state_json`` under which the action checkpoint
# (per-trace action records plus the pending interrupt) is stored.
ACTION_CHECKPOINT_KEY = "steward_action_checkpoint"
# Statuses for which a stored response is replayed instead of running the
# action again when the same trace id is seen.
TERMINAL_ACTION_STATUSES = {"succeeded", "blocked", "failed"}
class StewardGraphActionState(TypedDict, total=False):
    """State dictionary handed from node to node in the steward action graph.

    The type is declared with ``total=False``, so every key is optional and a
    node only returns the keys it wants to set.
    """

    # Inputs seeded when the graph is invoked.
    request: StewardActionExecuteRequest
    current_user: CurrentUserContext

    # Filled in by the checkpoint-load node; ``conversation`` is ``None`` when
    # the request carries no conversation id.
    conversation: AgentConversation | None
    trace_id: str
    existing_result: dict[str, Any]

    # The executed (or replayed) action result returned to the caller.
    response: StewardActionExecuteResponse
class StewardGraphActionRuntime:
"""用 LangGraph 包装小财管家白名单动作执行、checkpoint 和幂等重放。"""
def __init__(self, db: Session) -> None:
self.db = db
self.executor = StewardActionExecutor(db)
self._graph = self._build_graph()
def execute(
self,
request: StewardActionExecuteRequest,
current_user: CurrentUserContext,
) -> StewardActionExecuteResponse:
final_state = self._graph.invoke(
{
"request": request,
"current_user": current_user,
}
)
response = final_state.get("response")
if not isinstance(response, StewardActionExecuteResponse):
raise RuntimeError("LangGraph action runtime 未生成有效执行结果。")
return response
def _build_graph(self):
graph = StateGraph(StewardGraphActionState)
graph.add_node("action_checkpoint_load", self._load_checkpoint)
graph.add_node("action_execute_node", self._execute_action)
graph.add_node("action_checkpoint_persist", self._persist_checkpoint)
graph.add_edge(START, "action_checkpoint_load")
graph.add_conditional_edges(
"action_checkpoint_load",
self._route_after_checkpoint_load,
{
"replay": "action_checkpoint_persist",
"execute": "action_execute_node",
},
)
graph.add_edge("action_execute_node", "action_checkpoint_persist")
graph.add_edge("action_checkpoint_persist", END)
return graph.compile()
def _load_checkpoint(self, state: StewardGraphActionState) -> dict[str, Any]:
request = state["request"]
conversation = self._get_or_create_conversation(request, state["current_user"])
trace_id = self._resolve_trace_id(request)
if conversation is None or not trace_id:
return {
"conversation": conversation,
"trace_id": trace_id,
}
checkpoint = self._resolve_checkpoint(conversation)
existing = dict(checkpoint.get("actions", {}).get(trace_id) or {})
existing_response = existing.get("response")
status = str(existing.get("status") or "").strip()
if isinstance(existing_response, dict) and (
status in TERMINAL_ACTION_STATUSES
or (status == "needs_confirmation" and not request.confirmed)
):
response = StewardActionExecuteResponse.model_validate(existing_response)
response.result_payload = {
**dict(response.result_payload or {}),
"idempotent_replay": True,
}
response.trace = [
*list(response.trace or []),
self._trace("checkpoint_replay", client_trace_id=trace_id),
]
return {
"conversation": conversation,
"trace_id": trace_id,
"existing_result": existing,
"response": response,
}
return {
"conversation": conversation,
"trace_id": trace_id,
"existing_result": existing,
}
@staticmethod
def _route_after_checkpoint_load(state: StewardGraphActionState) -> str:
if isinstance(state.get("response"), StewardActionExecuteResponse):
return "replay"
return "execute"
def _execute_action(self, state: StewardGraphActionState) -> dict[str, StewardActionExecuteResponse]:
response = self.executor.execute(state["request"], state["current_user"])
return {"response": response}
def _persist_checkpoint(self, state: StewardGraphActionState) -> dict[str, StewardActionExecuteResponse]:
response = state["response"]
conversation = state.get("conversation")
trace_id = str(state.get("trace_id") or "").strip()
if conversation is None or not trace_id:
return {"response": response}
checkpoint = self._resolve_checkpoint(conversation)
actions = dict(checkpoint.get("actions") or {})
request = state["request"]
response_payload = response.model_dump(mode="json")
actions[trace_id] = {
"client_trace_id": trace_id,
"action_type": response.action_type,
"status": response.status,
"request": request.model_dump(mode="json"),
"response": response_payload,
"updated_at": datetime.now(UTC).isoformat(),
}
checkpoint["actions"] = actions
if response.status == "needs_confirmation":
checkpoint["pending_interrupt"] = {
"client_trace_id": trace_id,
"action_type": response.action_type,
"message": response.message,
"requires_confirmation": True,
"updated_at": datetime.now(UTC).isoformat(),
}
elif str(checkpoint.get("pending_interrupt", {}).get("client_trace_id") or "") == trace_id:
checkpoint["pending_interrupt"] = {}
conversation.state_json = {
**dict(conversation.state_json or {}),
ACTION_CHECKPOINT_KEY: checkpoint,
}
conversation.updated_at = datetime.now(UTC)
self.db.add(conversation)
self.db.commit()
return {"response": response}
def _get_or_create_conversation(
self,
request: StewardActionExecuteRequest,
current_user: CurrentUserContext,
) -> AgentConversation | None:
conversation_id = str(request.conversation_id or "").strip()
if not conversation_id:
return None
return AgentConversationService(self.db).get_or_create_conversation(
conversation_id=conversation_id,
user_id=current_user.username,
source="user_message",
context_json={
"session_type": "steward",
"entry_source": "steward_action_executor",
"steward_state": dict((request.context_json or {}).get("steward_state") or {}),
},
)
@staticmethod
def _resolve_checkpoint(conversation: AgentConversation) -> dict[str, Any]:
checkpoint = dict((conversation.state_json or {}).get(ACTION_CHECKPOINT_KEY) or {})
checkpoint.setdefault("actions", {})
checkpoint.setdefault("pending_interrupt", {})
return checkpoint
@staticmethod
def _resolve_trace_id(request: StewardActionExecuteRequest) -> str:
trace_id = str(request.client_trace_id or "").strip()
if trace_id:
return trace_id
action_type = str(request.action_type or "").strip()
task_id = str(request.task.task_id if request.task is not None else "").strip()
if action_type and task_id:
return f"{action_type}:{task_id}"
return ""
@staticmethod
def _trace(stage: str, **extra: Any) -> dict[str, Any]:
return {
"stage": stage,
"at": datetime.now(UTC).isoformat(),
**extra,
}