from __future__ import annotations from datetime import UTC, datetime from typing import Any, TypedDict from langgraph.graph import END, START, StateGraph from sqlalchemy.orm import Session from app.api.deps import CurrentUserContext from app.models.agent_conversation import AgentConversation from app.schemas.steward import StewardActionExecuteRequest, StewardActionExecuteResponse from app.services.agent_conversations import AgentConversationService from app.services.steward_action_executor import StewardActionExecutor ACTION_CHECKPOINT_KEY = "steward_action_checkpoint" TERMINAL_ACTION_STATUSES = {"succeeded", "blocked", "failed"} class StewardGraphActionState(TypedDict, total=False): request: StewardActionExecuteRequest current_user: CurrentUserContext conversation: AgentConversation | None trace_id: str existing_result: dict[str, Any] response: StewardActionExecuteResponse class StewardGraphActionRuntime: """用 LangGraph 包装小财管家白名单动作执行、checkpoint 和幂等重放。""" def __init__(self, db: Session) -> None: self.db = db self.executor = StewardActionExecutor(db) self._graph = self._build_graph() def execute( self, request: StewardActionExecuteRequest, current_user: CurrentUserContext, ) -> StewardActionExecuteResponse: final_state = self._graph.invoke( { "request": request, "current_user": current_user, } ) response = final_state.get("response") if not isinstance(response, StewardActionExecuteResponse): raise RuntimeError("LangGraph action runtime 未生成有效执行结果。") return response def _build_graph(self): graph = StateGraph(StewardGraphActionState) graph.add_node("action_checkpoint_load", self._load_checkpoint) graph.add_node("action_execute_node", self._execute_action) graph.add_node("action_checkpoint_persist", self._persist_checkpoint) graph.add_edge(START, "action_checkpoint_load") graph.add_conditional_edges( "action_checkpoint_load", self._route_after_checkpoint_load, { "replay": "action_checkpoint_persist", "execute": "action_execute_node", }, ) graph.add_edge("action_execute_node", "action_checkpoint_persist") graph.add_edge("action_checkpoint_persist", END) return graph.compile() def _load_checkpoint(self, state: StewardGraphActionState) -> dict[str, Any]: request = state["request"] conversation = self._get_or_create_conversation(request, state["current_user"]) trace_id = self._resolve_trace_id(request) if conversation is None or not trace_id: return { "conversation": conversation, "trace_id": trace_id, } checkpoint = self._resolve_checkpoint(conversation) existing = dict(checkpoint.get("actions", {}).get(trace_id) or {}) existing_response = existing.get("response") status = str(existing.get("status") or "").strip() if isinstance(existing_response, dict) and ( status in TERMINAL_ACTION_STATUSES or (status == "needs_confirmation" and not request.confirmed) ): response = StewardActionExecuteResponse.model_validate(existing_response) response.result_payload = { **dict(response.result_payload or {}), "idempotent_replay": True, } response.trace = [ *list(response.trace or []), self._trace("checkpoint_replay", client_trace_id=trace_id), ] return { "conversation": conversation, "trace_id": trace_id, "existing_result": existing, "response": response, } return { "conversation": conversation, "trace_id": trace_id, "existing_result": existing, } @staticmethod def _route_after_checkpoint_load(state: StewardGraphActionState) -> str: if isinstance(state.get("response"), StewardActionExecuteResponse): return "replay" return "execute" def _execute_action(self, state: StewardGraphActionState) -> dict[str, StewardActionExecuteResponse]: response = self.executor.execute(state["request"], state["current_user"]) return {"response": response} def _persist_checkpoint(self, state: StewardGraphActionState) -> dict[str, StewardActionExecuteResponse]: response = state["response"] conversation = state.get("conversation") trace_id = str(state.get("trace_id") or "").strip() if conversation is None or not trace_id: return {"response": response} checkpoint = self._resolve_checkpoint(conversation) actions = dict(checkpoint.get("actions") or {}) request = state["request"] response_payload = response.model_dump(mode="json") actions[trace_id] = { "client_trace_id": trace_id, "action_type": response.action_type, "status": response.status, "request": request.model_dump(mode="json"), "response": response_payload, "updated_at": datetime.now(UTC).isoformat(), } checkpoint["actions"] = actions if response.status == "needs_confirmation": checkpoint["pending_interrupt"] = { "client_trace_id": trace_id, "action_type": response.action_type, "message": response.message, "requires_confirmation": True, "updated_at": datetime.now(UTC).isoformat(), } elif str(checkpoint.get("pending_interrupt", {}).get("client_trace_id") or "") == trace_id: checkpoint["pending_interrupt"] = {} conversation.state_json = { **dict(conversation.state_json or {}), ACTION_CHECKPOINT_KEY: checkpoint, } conversation.updated_at = datetime.now(UTC) self.db.add(conversation) self.db.commit() return {"response": response} def _get_or_create_conversation( self, request: StewardActionExecuteRequest, current_user: CurrentUserContext, ) -> AgentConversation | None: conversation_id = str(request.conversation_id or "").strip() if not conversation_id: return None return AgentConversationService(self.db).get_or_create_conversation( conversation_id=conversation_id, user_id=current_user.username, source="user_message", context_json={ "session_type": "steward", "entry_source": "steward_action_executor", "steward_state": dict((request.context_json or {}).get("steward_state") or {}), }, ) @staticmethod def _resolve_checkpoint(conversation: AgentConversation) -> dict[str, Any]: checkpoint = dict((conversation.state_json or {}).get(ACTION_CHECKPOINT_KEY) or {}) checkpoint.setdefault("actions", {}) checkpoint.setdefault("pending_interrupt", {}) return checkpoint @staticmethod def _resolve_trace_id(request: StewardActionExecuteRequest) -> str: trace_id = str(request.client_trace_id or "").strip() if trace_id: return trace_id action_type = str(request.action_type or "").strip() task_id = str(request.task.task_id if request.task is not None else "").strip() if action_type and task_id: return f"{action_type}:{task_id}" return "" @staticmethod def _trace(stage: str, **extra: Any) -> dict[str, Any]: return { "stage": stage, "at": datetime.now(UTC).isoformat(), **extra, }