Refine travel reimbursement steward flow
Align planner, runtime rules, and policy assets so travel guidance matches the updated reimbursement workflow.
This commit is contained in:
@@ -1,19 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.steward import (
|
||||
StewardFlowStatePatch,
|
||||
StewardRuntimeDecisionRequest,
|
||||
StewardRuntimeDecisionResponse,
|
||||
)
|
||||
from app.services.runtime_chat import RuntimeChatService
|
||||
from app.services.steward_flow_state import StewardFlowStateService
|
||||
|
||||
|
||||
STEWARD_RUNTIME_DECISION_FUNCTION_NAME = "submit_steward_runtime_decision"
|
||||
|
||||
RUNTIME_NEXT_ACTIONS = {
|
||||
"plan_new_tasks",
|
||||
"continue_selected_flow",
|
||||
"submit_current_application",
|
||||
"continue_next_task",
|
||||
"fill_current_slot",
|
||||
@@ -22,6 +26,16 @@ RUNTIME_NEXT_ACTIONS = {
|
||||
"no_op",
|
||||
}
|
||||
|
||||
FIELD_LABELS = {
|
||||
"transport_mode": "出行方式",
|
||||
"expense_type": "费用类型",
|
||||
"time_range": "时间",
|
||||
"location": "地点",
|
||||
"reason": "事由",
|
||||
"amount": "金额",
|
||||
"attachments": "附件",
|
||||
}
|
||||
|
||||
|
||||
class StewardRuntimeDecisionAgent:
|
||||
"""用小财管家运行时上下文判断用户当前输入应落到哪个等待动作。"""
|
||||
@@ -31,6 +45,9 @@ class StewardRuntimeDecisionAgent:
|
||||
|
||||
def decide(self, request: StewardRuntimeDecisionRequest) -> StewardRuntimeDecisionResponse:
|
||||
normalized_request = self._normalize_request(request)
|
||||
selected_flow_decision = self._build_selected_flow_decision(normalized_request, [])
|
||||
if selected_flow_decision is not None:
|
||||
return selected_flow_decision
|
||||
result = self.runtime_chat_service.complete_with_tool_call(
|
||||
self._build_messages(normalized_request),
|
||||
tools=[self._build_tool_schema()],
|
||||
@@ -47,18 +64,104 @@ class StewardRuntimeDecisionAgent:
|
||||
if result.tool_call is not None and result.tool_call.name == STEWARD_RUNTIME_DECISION_FUNCTION_NAME:
|
||||
response = self._build_response_from_model_payload(result.tool_call.arguments, normalized_request, traces)
|
||||
if response is not None:
|
||||
return response
|
||||
return self._build_rule_fallback(normalized_request, traces)
|
||||
return self._attach_updated_steward_state(response, normalized_request)
|
||||
return self._attach_updated_steward_state(
|
||||
self._build_rule_fallback(normalized_request, traces),
|
||||
normalized_request,
|
||||
)
|
||||
|
||||
def _build_selected_flow_decision(
|
||||
self,
|
||||
request: StewardRuntimeDecisionRequest,
|
||||
traces: list[dict[str, Any]],
|
||||
) -> StewardRuntimeDecisionResponse | None:
|
||||
selected_flow_id = self._resolve_selected_pending_flow_id(
|
||||
request.runtime_state,
|
||||
request.user_message,
|
||||
)
|
||||
if not selected_flow_id:
|
||||
return None
|
||||
next_state = StewardFlowStateService().confirm_flow(
|
||||
request.runtime_state.get("steward_state") if isinstance(request.runtime_state.get("steward_state"), dict) else {},
|
||||
selected_flow_id,
|
||||
)
|
||||
return StewardRuntimeDecisionResponse(
|
||||
decision_source="rule_fallback",
|
||||
next_action="continue_selected_flow",
|
||||
target_task_id=selected_flow_id,
|
||||
response_text=self._build_selected_flow_response_text(selected_flow_id),
|
||||
rationale="已按你选择的候选流程继续处理。",
|
||||
steward_state=next_state,
|
||||
model_call_traces=traces,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_request(request: StewardRuntimeDecisionRequest) -> StewardRuntimeDecisionRequest:
|
||||
context_json = request.context_json if isinstance(request.context_json, dict) else {}
|
||||
runtime_state = request.runtime_state if isinstance(request.runtime_state, dict) else {}
|
||||
return StewardRuntimeDecisionRequest(
|
||||
user_message=str(request.user_message or "").strip(),
|
||||
session_type=str(request.session_type or "steward").strip() or "steward",
|
||||
runtime_state=request.runtime_state if isinstance(request.runtime_state, dict) else {},
|
||||
context_json=request.context_json if isinstance(request.context_json, dict) else {},
|
||||
runtime_state=StewardRuntimeDecisionAgent._hydrate_runtime_state(runtime_state, context_json),
|
||||
context_json=context_json,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _hydrate_runtime_state(
|
||||
runtime_state: dict[str, Any],
|
||||
context_json: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
hydrated = dict(runtime_state or {})
|
||||
steward_state = StewardRuntimeDecisionAgent._resolve_steward_state(context_json)
|
||||
if steward_state:
|
||||
hydrated.setdefault("steward_state", steward_state)
|
||||
if StewardRuntimeDecisionAgent._has_runtime_anchor(hydrated) or not steward_state:
|
||||
return hydrated
|
||||
|
||||
active_flow = str(steward_state.get("active_flow") or "").strip()
|
||||
flows = steward_state.get("flows") if isinstance(steward_state.get("flows"), dict) else {}
|
||||
flow = flows.get(active_flow) if isinstance(flows, dict) else None
|
||||
if not isinstance(flow, dict):
|
||||
return hydrated
|
||||
|
||||
missing_fields = [
|
||||
str(item or "").strip()
|
||||
for item in list(flow.get("missing_fields") or [])
|
||||
if str(item or "").strip()
|
||||
]
|
||||
hydrated["current_task"] = {
|
||||
"task_id": active_flow,
|
||||
"task_type": "expense_application" if active_flow == "travel_application" else "reimbursement",
|
||||
"ontology_fields": dict(flow.get("fields") or {}),
|
||||
"missing_fields": missing_fields,
|
||||
}
|
||||
if missing_fields:
|
||||
hydrated["waiting_for"] = "steward_flow_field_completion"
|
||||
else:
|
||||
hydrated["waiting_for"] = "steward_flow_confirmation"
|
||||
return hydrated
|
||||
|
||||
@staticmethod
|
||||
def _resolve_steward_state(context_json: dict[str, Any]) -> dict[str, Any]:
|
||||
direct_state = context_json.get("steward_state") or context_json.get("stewardState")
|
||||
if isinstance(direct_state, dict) and direct_state:
|
||||
return direct_state
|
||||
conversation_state = context_json.get("conversation_state")
|
||||
if isinstance(conversation_state, dict):
|
||||
nested_state = conversation_state.get("steward_state") or conversation_state.get("stewardState")
|
||||
if isinstance(nested_state, dict) and nested_state:
|
||||
return nested_state
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _has_runtime_anchor(runtime_state: dict[str, Any]) -> bool:
|
||||
if str(runtime_state.get("waiting_for") or "").strip():
|
||||
return True
|
||||
for key in ("pending_application", "pending_steward_action", "pending_slot_action", "current_task"):
|
||||
if isinstance(runtime_state.get(key), dict) and runtime_state[key]:
|
||||
return True
|
||||
return bool(runtime_state.get("remaining_tasks") or runtime_state.get("completed_tasks"))
|
||||
|
||||
@staticmethod
|
||||
def _build_messages(request: StewardRuntimeDecisionRequest) -> list[dict[str, Any]]:
|
||||
payload = {
|
||||
@@ -177,6 +280,34 @@ class StewardRuntimeDecisionAgent:
|
||||
rationale="模型运行时决策暂不可用,我先按当前待确认的下一项任务继续处理。",
|
||||
model_call_traces=traces,
|
||||
)
|
||||
if waiting_for == "steward_flow_field_completion":
|
||||
current_task = state.get("current_task") if isinstance(state.get("current_task"), dict) else {}
|
||||
missing_fields = [
|
||||
str(item or "").strip()
|
||||
for item in list(current_task.get("missing_fields") or [])
|
||||
if str(item or "").strip()
|
||||
]
|
||||
field_key = missing_fields[0] if missing_fields else ""
|
||||
if field_key and request.user_message:
|
||||
return StewardRuntimeDecisionResponse(
|
||||
decision_source="rule_fallback",
|
||||
next_action="fill_current_slot",
|
||||
target_task_id=str(current_task.get("task_id") or ""),
|
||||
field_key=field_key,
|
||||
field_value=request.user_message,
|
||||
rationale="模型运行时决策暂不可用,我先把你的补充写入当前小财管家流程字段。",
|
||||
model_call_traces=traces,
|
||||
)
|
||||
if field_key:
|
||||
return StewardRuntimeDecisionResponse(
|
||||
decision_source="rule_fallback",
|
||||
next_action="ask_user",
|
||||
target_task_id=str(current_task.get("task_id") or ""),
|
||||
field_key=field_key,
|
||||
question=f"请补充{FIELD_LABELS.get(field_key, field_key)}。",
|
||||
rationale="当前小财管家流程仍缺少必要字段。",
|
||||
model_call_traces=traces,
|
||||
)
|
||||
if waiting_for:
|
||||
return StewardRuntimeDecisionResponse(
|
||||
decision_source="rule_fallback",
|
||||
@@ -192,6 +323,104 @@ class StewardRuntimeDecisionAgent:
|
||||
model_call_traces=traces,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_selected_pending_flow_id(runtime_state: dict[str, Any], user_message: str) -> str:
|
||||
steward_state = runtime_state.get("steward_state")
|
||||
if not isinstance(steward_state, dict):
|
||||
return ""
|
||||
pending = steward_state.get("pending_flow_confirmation")
|
||||
if not isinstance(pending, dict) or pending.get("status") != "pending":
|
||||
return ""
|
||||
message = re.sub(r"\s+", "", str(user_message or ""))
|
||||
if not message:
|
||||
return ""
|
||||
candidates = pending.get("candidate_flows") if isinstance(pending.get("candidate_flows"), list) else []
|
||||
for candidate in candidates:
|
||||
if not isinstance(candidate, dict):
|
||||
continue
|
||||
flow_id = str(candidate.get("flow_id") or "").strip()
|
||||
label = re.sub(r"\s+", "", str(candidate.get("label") or ""))
|
||||
if flow_id == "travel_application" and (
|
||||
message in {"补办出差申请", "出差申请", "申请", "补申请"}
|
||||
or (label and message == label)
|
||||
):
|
||||
return flow_id
|
||||
if flow_id == "travel_reimbursement" and (
|
||||
message in {"发起费用报销", "费用报销", "报销", "发起报销"}
|
||||
or (label and message == label)
|
||||
):
|
||||
return flow_id
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _build_selected_flow_response_text(flow_id: str) -> str:
|
||||
if flow_id == "travel_application":
|
||||
return "已确认按 **补办出差申请** 继续,我会基于当前出差信息整理申请材料。"
|
||||
return "已确认按 **发起费用报销** 继续,我会基于当前出差信息整理报销材料。"
|
||||
|
||||
@staticmethod
|
||||
def _clean_text(value: Any) -> str:
|
||||
return str(value or "").strip()
|
||||
|
||||
def _attach_updated_steward_state(
|
||||
self,
|
||||
response: StewardRuntimeDecisionResponse,
|
||||
request: StewardRuntimeDecisionRequest,
|
||||
) -> StewardRuntimeDecisionResponse:
|
||||
steward_state = request.runtime_state.get("steward_state")
|
||||
if not isinstance(steward_state, dict) or not steward_state:
|
||||
return response
|
||||
if response.next_action == "continue_selected_flow":
|
||||
flow_id = self._resolve_target_flow_id(response, steward_state)
|
||||
if flow_id:
|
||||
next_state = StewardFlowStateService().confirm_flow(steward_state, flow_id)
|
||||
return response.model_copy(update={"steward_state": next_state})
|
||||
return response.model_copy(update={"steward_state": steward_state})
|
||||
if response.next_action != "fill_current_slot" or not response.field_key:
|
||||
return response.model_copy(update={"steward_state": steward_state})
|
||||
|
||||
flow_id = self._resolve_target_flow_id(response, steward_state)
|
||||
if not flow_id:
|
||||
return response.model_copy(update={"steward_state": steward_state})
|
||||
current_flow = self._resolve_flow(steward_state, flow_id)
|
||||
remaining_missing_fields = [
|
||||
key
|
||||
for key in list(current_flow.get("missing_fields") or [])
|
||||
if str(key or "").strip() and str(key or "").strip() != response.field_key
|
||||
]
|
||||
next_state = StewardFlowStateService().merge_state(
|
||||
steward_state,
|
||||
StewardFlowStatePatch(
|
||||
active_flow=flow_id, # type: ignore[arg-type]
|
||||
flow_id=flow_id, # type: ignore[arg-type]
|
||||
intent=str(current_flow.get("intent") or "").strip(),
|
||||
status="collecting" if remaining_missing_fields else "ready_for_confirmation",
|
||||
fields={response.field_key: response.field_value},
|
||||
missing_fields=remaining_missing_fields,
|
||||
evidence=[
|
||||
{
|
||||
"source": "runtime_user_message",
|
||||
"field": response.field_key,
|
||||
"text": request.user_message,
|
||||
}
|
||||
],
|
||||
),
|
||||
)
|
||||
return response.model_copy(update={"steward_state": next_state})
|
||||
|
||||
@staticmethod
|
||||
def _resolve_target_flow_id(
|
||||
response: StewardRuntimeDecisionResponse,
|
||||
steward_state: dict[str, Any],
|
||||
) -> str:
|
||||
target = str(response.target_task_id or "").strip()
|
||||
if target in {"travel_application", "travel_reimbursement"}:
|
||||
return target
|
||||
active_flow = str(steward_state.get("active_flow") or "").strip()
|
||||
return active_flow if active_flow in {"travel_application", "travel_reimbursement"} else ""
|
||||
|
||||
@staticmethod
|
||||
def _resolve_flow(steward_state: dict[str, Any], flow_id: str) -> dict[str, Any]:
|
||||
flows = steward_state.get("flows") if isinstance(steward_state.get("flows"), dict) else {}
|
||||
flow = flows.get(flow_id) if isinstance(flows, dict) else {}
|
||||
return dict(flow) if isinstance(flow, dict) else {}
|
||||
|
||||
Reference in New Issue
Block a user