205 lines
7.9 KiB
Python
205 lines
7.9 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from datetime import UTC, datetime
|
||
|
|
from typing import Any, TypedDict
|
||
|
|
|
||
|
|
from langgraph.graph import END, START, StateGraph
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.api.deps import CurrentUserContext
|
||
|
|
from app.models.agent_conversation import AgentConversation
|
||
|
|
from app.schemas.steward import StewardActionExecuteRequest, StewardActionExecuteResponse
|
||
|
|
from app.services.agent_conversations import AgentConversationService
|
||
|
|
from app.services.steward_action_executor import StewardActionExecutor
|
||
|
|
|
||
|
|
# Key inside ``AgentConversation.state_json`` under which this runtime stores
# its per-conversation action checkpoint.
ACTION_CHECKPOINT_KEY = "steward_action_checkpoint"

# Stored action statuses that are final: a repeated request carrying the same
# trace id is answered by replaying the stored response instead of executing
# the action again. Immutable so importers cannot alter the replay policy.
TERMINAL_ACTION_STATUSES = frozenset({"succeeded", "blocked", "failed"})
|
||
|
|
|
||
|
|
|
||
|
|
class StewardGraphActionState(TypedDict, total=False):
    """State carried between the nodes of the steward action graph.

    Every key is optional (``total=False``): the graph is invoked with only
    ``request`` and ``current_user``, and later nodes fill in the rest.
    """

    # The action request passed to the runtime's ``execute`` call.
    request: StewardActionExecuteRequest
    # The authenticated caller the action runs on behalf of.
    current_user: CurrentUserContext
    # Conversation holding the checkpoint; None when the request has no conversation id.
    conversation: AgentConversation | None
    # Idempotency key for the action; empty string when none can be derived.
    trace_id: str
    # Previously stored checkpoint entry for ``trace_id`` (empty dict if none).
    existing_result: dict[str, Any]
    # Final response, either replayed from the checkpoint or freshly executed.
    response: StewardActionExecuteResponse
|
||
|
|
|
||
|
|
|
||
|
|
class StewardGraphActionRuntime:
|
||
|
|
"""用 LangGraph 包装小财管家白名单动作执行、checkpoint 和幂等重放。"""
|
||
|
|
|
||
|
|
def __init__(self, db: Session) -> None:
|
||
|
|
self.db = db
|
||
|
|
self.executor = StewardActionExecutor(db)
|
||
|
|
self._graph = self._build_graph()
|
||
|
|
|
||
|
|
def execute(
|
||
|
|
self,
|
||
|
|
request: StewardActionExecuteRequest,
|
||
|
|
current_user: CurrentUserContext,
|
||
|
|
) -> StewardActionExecuteResponse:
|
||
|
|
final_state = self._graph.invoke(
|
||
|
|
{
|
||
|
|
"request": request,
|
||
|
|
"current_user": current_user,
|
||
|
|
}
|
||
|
|
)
|
||
|
|
response = final_state.get("response")
|
||
|
|
if not isinstance(response, StewardActionExecuteResponse):
|
||
|
|
raise RuntimeError("LangGraph action runtime 未生成有效执行结果。")
|
||
|
|
return response
|
||
|
|
|
||
|
|
def _build_graph(self):
|
||
|
|
graph = StateGraph(StewardGraphActionState)
|
||
|
|
graph.add_node("action_checkpoint_load", self._load_checkpoint)
|
||
|
|
graph.add_node("action_execute_node", self._execute_action)
|
||
|
|
graph.add_node("action_checkpoint_persist", self._persist_checkpoint)
|
||
|
|
|
||
|
|
graph.add_edge(START, "action_checkpoint_load")
|
||
|
|
graph.add_conditional_edges(
|
||
|
|
"action_checkpoint_load",
|
||
|
|
self._route_after_checkpoint_load,
|
||
|
|
{
|
||
|
|
"replay": "action_checkpoint_persist",
|
||
|
|
"execute": "action_execute_node",
|
||
|
|
},
|
||
|
|
)
|
||
|
|
graph.add_edge("action_execute_node", "action_checkpoint_persist")
|
||
|
|
graph.add_edge("action_checkpoint_persist", END)
|
||
|
|
return graph.compile()
|
||
|
|
|
||
|
|
def _load_checkpoint(self, state: StewardGraphActionState) -> dict[str, Any]:
|
||
|
|
request = state["request"]
|
||
|
|
conversation = self._get_or_create_conversation(request, state["current_user"])
|
||
|
|
trace_id = self._resolve_trace_id(request)
|
||
|
|
if conversation is None or not trace_id:
|
||
|
|
return {
|
||
|
|
"conversation": conversation,
|
||
|
|
"trace_id": trace_id,
|
||
|
|
}
|
||
|
|
|
||
|
|
checkpoint = self._resolve_checkpoint(conversation)
|
||
|
|
existing = dict(checkpoint.get("actions", {}).get(trace_id) or {})
|
||
|
|
existing_response = existing.get("response")
|
||
|
|
status = str(existing.get("status") or "").strip()
|
||
|
|
if isinstance(existing_response, dict) and (
|
||
|
|
status in TERMINAL_ACTION_STATUSES
|
||
|
|
or (status == "needs_confirmation" and not request.confirmed)
|
||
|
|
):
|
||
|
|
response = StewardActionExecuteResponse.model_validate(existing_response)
|
||
|
|
response.result_payload = {
|
||
|
|
**dict(response.result_payload or {}),
|
||
|
|
"idempotent_replay": True,
|
||
|
|
}
|
||
|
|
response.trace = [
|
||
|
|
*list(response.trace or []),
|
||
|
|
self._trace("checkpoint_replay", client_trace_id=trace_id),
|
||
|
|
]
|
||
|
|
return {
|
||
|
|
"conversation": conversation,
|
||
|
|
"trace_id": trace_id,
|
||
|
|
"existing_result": existing,
|
||
|
|
"response": response,
|
||
|
|
}
|
||
|
|
|
||
|
|
return {
|
||
|
|
"conversation": conversation,
|
||
|
|
"trace_id": trace_id,
|
||
|
|
"existing_result": existing,
|
||
|
|
}
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _route_after_checkpoint_load(state: StewardGraphActionState) -> str:
|
||
|
|
if isinstance(state.get("response"), StewardActionExecuteResponse):
|
||
|
|
return "replay"
|
||
|
|
return "execute"
|
||
|
|
|
||
|
|
def _execute_action(self, state: StewardGraphActionState) -> dict[str, StewardActionExecuteResponse]:
|
||
|
|
response = self.executor.execute(state["request"], state["current_user"])
|
||
|
|
return {"response": response}
|
||
|
|
|
||
|
|
def _persist_checkpoint(self, state: StewardGraphActionState) -> dict[str, StewardActionExecuteResponse]:
|
||
|
|
response = state["response"]
|
||
|
|
conversation = state.get("conversation")
|
||
|
|
trace_id = str(state.get("trace_id") or "").strip()
|
||
|
|
if conversation is None or not trace_id:
|
||
|
|
return {"response": response}
|
||
|
|
|
||
|
|
checkpoint = self._resolve_checkpoint(conversation)
|
||
|
|
actions = dict(checkpoint.get("actions") or {})
|
||
|
|
request = state["request"]
|
||
|
|
response_payload = response.model_dump(mode="json")
|
||
|
|
actions[trace_id] = {
|
||
|
|
"client_trace_id": trace_id,
|
||
|
|
"action_type": response.action_type,
|
||
|
|
"status": response.status,
|
||
|
|
"request": request.model_dump(mode="json"),
|
||
|
|
"response": response_payload,
|
||
|
|
"updated_at": datetime.now(UTC).isoformat(),
|
||
|
|
}
|
||
|
|
checkpoint["actions"] = actions
|
||
|
|
if response.status == "needs_confirmation":
|
||
|
|
checkpoint["pending_interrupt"] = {
|
||
|
|
"client_trace_id": trace_id,
|
||
|
|
"action_type": response.action_type,
|
||
|
|
"message": response.message,
|
||
|
|
"requires_confirmation": True,
|
||
|
|
"updated_at": datetime.now(UTC).isoformat(),
|
||
|
|
}
|
||
|
|
elif str(checkpoint.get("pending_interrupt", {}).get("client_trace_id") or "") == trace_id:
|
||
|
|
checkpoint["pending_interrupt"] = {}
|
||
|
|
|
||
|
|
conversation.state_json = {
|
||
|
|
**dict(conversation.state_json or {}),
|
||
|
|
ACTION_CHECKPOINT_KEY: checkpoint,
|
||
|
|
}
|
||
|
|
conversation.updated_at = datetime.now(UTC)
|
||
|
|
self.db.add(conversation)
|
||
|
|
self.db.commit()
|
||
|
|
return {"response": response}
|
||
|
|
|
||
|
|
def _get_or_create_conversation(
|
||
|
|
self,
|
||
|
|
request: StewardActionExecuteRequest,
|
||
|
|
current_user: CurrentUserContext,
|
||
|
|
) -> AgentConversation | None:
|
||
|
|
conversation_id = str(request.conversation_id or "").strip()
|
||
|
|
if not conversation_id:
|
||
|
|
return None
|
||
|
|
return AgentConversationService(self.db).get_or_create_conversation(
|
||
|
|
conversation_id=conversation_id,
|
||
|
|
user_id=current_user.username,
|
||
|
|
source="user_message",
|
||
|
|
context_json={
|
||
|
|
"session_type": "steward",
|
||
|
|
"entry_source": "steward_action_executor",
|
||
|
|
"steward_state": dict((request.context_json or {}).get("steward_state") or {}),
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _resolve_checkpoint(conversation: AgentConversation) -> dict[str, Any]:
|
||
|
|
checkpoint = dict((conversation.state_json or {}).get(ACTION_CHECKPOINT_KEY) or {})
|
||
|
|
checkpoint.setdefault("actions", {})
|
||
|
|
checkpoint.setdefault("pending_interrupt", {})
|
||
|
|
return checkpoint
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _resolve_trace_id(request: StewardActionExecuteRequest) -> str:
|
||
|
|
trace_id = str(request.client_trace_id or "").strip()
|
||
|
|
if trace_id:
|
||
|
|
return trace_id
|
||
|
|
action_type = str(request.action_type or "").strip()
|
||
|
|
task_id = str(request.task.task_id if request.task is not None else "").strip()
|
||
|
|
if action_type and task_id:
|
||
|
|
return f"{action_type}:{task_id}"
|
||
|
|
return ""
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _trace(stage: str, **extra: Any) -> dict[str, Any]:
|
||
|
|
return {
|
||
|
|
"stage": stage,
|
||
|
|
"at": datetime.now(UTC).isoformat(),
|
||
|
|
**extra,
|
||
|
|
}
|