# File: X-Financial/server/tests/test_steward_planner.py
# (Captured from a repository web view: 550 lines, 23 KiB, Python.)

from __future__ import annotations
import json
from fastapi.testclient import TestClient
from app.main import create_app
from app.schemas.steward import StewardAttachmentInput, StewardPlanRequest
from app.services.steward_intent_agent import StewardIntentAgentResult
from app.services.steward_planner import StewardPlannerService
class FakeFunctionCallingIntentAgent:
    """Deterministic stand-in for the LLM intent agent.

    ``detect`` returns a canned function-calling result: one transport
    reimbursement task, a matching attachment group and a single successful
    model-call trace. The ``assert`` statements also check the arguments the
    planner hands to the agent.
    """

    @staticmethod
    def _thinking_event():
        return {
            "stage": "task_split",
            "title": "识别复合报销意图",
            "content": "模型工具调用识别出 1 个报销任务,并关联本次上传的交通附件。",
        }

    @staticmethod
    def _reimbursement_task():
        # The field names mix legacy aliases (occurred_date, transport_type,
        # reason_value) with an unregistered key. Tests use this to check that the
        # planner canonicalizes or drops them.
        raw_fields = {
            "occurred_date": "昨天",
            "transport_type": "出租车",
            "reason_value": "客户现场沟通",
            "expense_type": "交通费",
            "unregistered_field": "不能进入业务字段",
        }
        return {
            "task_type": "reimbursement",
            "title": "费用报销 2026-06-03 交通",
            "summary": "报销昨天客户现场沟通产生的交通费。",
            "confidence": 0.91,
            "ontology_fields": raw_fields,
            "missing_fields": ["amount", "transport_type"],
        }

    @staticmethod
    def _attachment_group():
        return {
            "target_task_index": 1,
            "scene": "transport",
            "scene_label": "交通费用",
            "attachment_names": ["出租车票.png"],
            "excluded_attachment_names": ["客户招待发票.jpg"],
            "confidence": 0.86,
            "rationale": "出租车票与交通报销任务匹配,招待发票不归入该任务。",
        }

    @staticmethod
    def _trace():
        return {
            "slot": "main",
            "provider": "OpenAI Compatible",
            "model": "gpt-test",
            "attempt": 1,
            "status": "succeeded",
        }

    def detect(self, request, *, base_date, canonical_fields):
        # The planner must expose the canonical field registry and derive the
        # base date from client_now_iso before delegating to the agent.
        assert "expense_type" in canonical_fields
        assert base_date.isoformat() == "2026-06-04"
        payload = {
            "thinking_events": [self._thinking_event()],
            "tasks": [self._reimbursement_task()],
            "attachment_groups": [self._attachment_group()],
        }
        return StewardIntentAgentResult(
            payload=payload,
            model_call_traces=[self._trace()],
        )
class CountingFunctionCallingIntentAgent(FakeFunctionCallingIntentAgent):
    """Canned function-calling agent that also records how often it was invoked."""

    def __init__(self) -> None:
        self.calls = 0

    def detect(self, request, *, base_date, canonical_fields):
        # Count before delegating so a failing parent assertion is still counted.
        self.calls = self.calls + 1
        result = super().detect(
            request,
            base_date=base_date,
            canonical_fields=canonical_fields,
        )
        return result
class CountingNoResultIntentAgent:
    """Intent agent that never yields a plan but counts its invocations."""

    calls: int

    def __init__(self) -> None:
        self.calls = 0

    def detect(self, request, *, base_date, canonical_fields):
        self.calls += 1
        # No result: the planner has to fall back to its own rules.
        result = None
        return result
class EmptyFunctionCallingIntentAgent:
    """Intent agent that simulates an unavailable LLM by returning no result."""

    def detect(self, request, *, base_date, canonical_fields):
        # The arguments are accepted only to satisfy the agent interface.
        del request, base_date, canonical_fields
        return None
class EntertainmentFunctionCallingIntentAgent:
    """Intent agent returning a single business-entertainment reimbursement task.

    The ontology values are raw model-style strings (a relative date and a
    Chinese expense label). Tests use them to check that the planner normalizes
    them into canonical values.
    """

    def detect(self, request, *, base_date, canonical_fields):
        raw_fields = {
            "time_range": "昨天",
            "expense_type": "业务招待费",
            "reason": "业务招待",
        }
        entertainment_task = {
            "task_type": "reimbursement",
            "title": "业务招待费报销",
            "summary": "报销昨天业务招待费。",
            "confidence": 0.9,
            "ontology_fields": raw_fields,
            "missing_fields": [],
        }
        payload = {
            "thinking_events": [],
            "tasks": [entertainment_task],
            "attachment_groups": [],
        }
        return StewardIntentAgentResult(payload=payload, model_call_traces=[])
class ApplicationFunctionCallingIntentAgent:
    """Intent agent returning a Beijing travel application with no reported gaps.

    The task lists no missing fields even though no transport mode is given.
    Tests use this to check that the planner enforces the transport-mode gap
    on its own.
    """

    def detect(self, request, *, base_date, canonical_fields):
        split_event = {
            "stage": "task_split",
            "title": "识别出差申请",
            "content": "模型识别到用户要发起北京出差申请,并且后续还有报销事项。",
        }
        application_fields = {
            "time_range": "明天",
            "location": "北京",
            "expense_type": "差旅",
            "reason": "支撑国网仿生产部署",
        }
        application_task = {
            "task_type": "expense_application",
            "title": "北京出差申请",
            "summary": "明天前往北京出差3天支撑国网仿生产部署。",
            "confidence": 0.94,
            "ontology_fields": application_fields,
            "missing_fields": [],
        }
        return StewardIntentAgentResult(
            payload={
                "thinking_events": [split_event],
                "tasks": [application_task],
                "attachment_groups": [],
            },
            model_call_traces=[],
        )
class PendingFlowFunctionCallingIntentAgent:
    """Intent agent that asks the user to choose between two travel flows.

    It returns no tasks. Instead it returns a ``pending_flow_confirmation``
    with two candidates (a retroactive application and a reimbursement) that
    share the same raw ontology values. ``source_message`` echoes
    ``request.message``.
    """

    @staticmethod
    def _candidate_fields():
        # Build a new dict on every call so the two candidates never share
        # (and accidentally co-mutate) one object.
        return {
            "time_range": "2月20日",
            "location": "上海",
            "expense_type": "差旅",
            "reason": "辅助国网仿生产环境部署",
        }

    def detect(self, request, *, base_date, canonical_fields):
        confirmation_event = {
            "stage": "flow_confirmation",
            "title": "识别到出差事项但动作不明确",
            "content": "用户提供了时间、地点和事由,但没有明确要补办申请还是发起报销。",
        }
        application_candidate = {
            "flow_id": "travel_application",
            "label": "补办出差申请",
            "confidence": 0.52,
            "reason": "这句话可以理解为补办出差申请。",
            "ontology_fields": self._candidate_fields(),
            "missing_fields": ["transport_mode"],
        }
        reimbursement_candidate = {
            "flow_id": "travel_reimbursement",
            "label": "发起费用报销",
            "confidence": 0.48,
            "reason": "这句话也可能是在为已发生出差发起报销。",
            "ontology_fields": self._candidate_fields(),
            "missing_fields": [],
        }
        pending = {
            "status": "pending",
            "source_message": request.message,
            "reason": "缺少申请或报销动作词,需要用户确认流程方向。",
            "candidate_flows": [application_candidate, reimbursement_candidate],
        }
        return StewardIntentAgentResult(
            payload={
                "thinking_events": [confirmation_event],
                "pending_flow_confirmation": pending,
                "tasks": [],
                "attachment_groups": [],
            },
            model_call_traces=[],
        )
class AmbiguousApplicationFunctionCallingIntentAgent:
    """Intent agent that misclassifies a verb-less travel description as an application.

    It returns one ``expense_application`` task and a successful trace. Tests
    use this to check that the planner overrides such a direct classification
    with a flow-confirmation prompt.
    """

    def detect(self, request, *, base_date, canonical_fields):
        misread_note = {
            "stage": "task_split",
            "title": "模型直接判定为申请",
            "content": "模型误把无动作词的历史出差描述直接判定为申请。",
        }
        misread_task = {
            "task_type": "expense_application",
            "title": "上海出差申请",
            "summary": "2月20-23日去上海出差辅助国网仿生产环境部署。",
            "confidence": 0.9,
            "ontology_fields": {
                "time_range": "2月20日",
                "location": "上海",
                "expense_type": "差旅",
                "reason": "辅助国网仿生产环境部署",
            },
            "missing_fields": ["transport_mode"],
        }
        traces = [{"status": "succeeded"}]
        return StewardIntentAgentResult(
            payload={
                "thinking_events": [misread_note],
                "tasks": [misread_task],
                "attachment_groups": [],
            },
            model_call_traces=traces,
        )
def test_steward_planner_uses_llm_function_calling_plan_when_available() -> None:
    """A successful function-calling result is adopted and mapped to canonical fields.

    Legacy aliases and unregistered keys from the model must not survive.
    Provided fields drop out of ``missing_fields``, and the attachment grouping
    from the model is kept.
    """
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u0037\u6708\u0032\u65e5\u53bb\u5317\u4eac\u51fa\u5dee\uff0c\u5e76\u4e14\u6211\u8981\u62a5\u9500\u6628\u5929\u5ba2\u6237\u73b0\u573a\u6c9f\u901a\u7684\u4ea4\u901a\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
        attachments=[
            StewardAttachmentInput(name="出租车票.png"),
            StewardAttachmentInput(name="客户招待发票.jpg"),
        ],
    )
    planner = StewardPlannerService(intent_agent=FakeFunctionCallingIntentAgent())
    plan = planner.build_plan(request)

    assert plan.planning_source == "llm_function_call"
    assert plan.model_call_traces[0]["status"] == "succeeded"
    assert len(plan.tasks) == 1

    task = plan.tasks[0]
    canonical = task.ontology_fields
    expected_values = {
        "time_range": "2026-06-03",
        "transport_mode": "taxi",
        "reason": "客户现场沟通",
        "expense_type": "transport",
    }
    for key, value in expected_values.items():
        assert canonical[key] == value
    for dropped_key in ("occurred_date", "transport_type", "reason_value", "unregistered_field"):
        assert dropped_key not in canonical
    assert task.missing_fields == ["amount"]

    group = plan.attachment_groups[0]
    assert group.attachment_names == ["出租车票.png"]
    assert group.excluded_attachment_names == ["客户招待发票.jpg"]
    assert plan.thinking_events[0].stage == "llm_function_call"
def test_steward_planner_normalizes_llm_business_entertainment_expense_type() -> None:
    """The raw entertainment expense label and relative date from the model become canonical values."""
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u0037\u6708\u0032\u65e5\u53bb\u5317\u4eac\u51fa\u5dee\uff0c\u5e76\u4e14\u62a5\u9500\u6628\u5929\u4e1a\u52a1\u62db\u5f85\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    agent = EntertainmentFunctionCallingIntentAgent()
    plan = StewardPlannerService(intent_agent=agent).build_plan(request)

    assert plan.planning_source == "llm_function_call"
    first_fields = plan.tasks[0].ontology_fields
    assert first_fields["expense_type"] == "entertainment"
    assert first_fields["time_range"] == "2026-06-03"
def test_steward_planner_enforces_application_transport_gap_after_function_calling() -> None:
    """A model-reported application without a transport mode still gets that gap flagged.

    The planner must add ``transport_mode`` to ``missing_fields``. It must also
    emit a ``business_gap_check`` thinking event that explains the gap and the
    accepted options.
    """
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u660e\u5929\u51fa\u5dee\u5317\u4eac\u0033\u5929\uff0c\u652f\u6491\u56fd\u7f51\u4eff\u751f\u4ea7\u90e8\u7f72\uff0c\u5e76\u4e14\u6211\u8981\u62a5\u9500\u6628\u5929\u7684\u4ea4\u901a\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    planner = StewardPlannerService(intent_agent=ApplicationFunctionCallingIntentAgent())
    plan = planner.build_plan(request)

    assert plan.planning_source == "llm_function_call"
    assert plan.tasks[0].missing_fields == ["transport_mode"]

    gap_events = list(
        filter(lambda event: event.stage == "business_gap_check", plan.thinking_events)
    )
    assert gap_events
    explanation = gap_events[0].content
    assert "没有说明出行方式" in explanation
    assert "火车、飞机或轮船" in explanation
def test_steward_planner_returns_pending_flow_confirmation_from_llm() -> None:
    """An ambiguous travel description yields a flow-confirmation plan.

    The plan offers two candidates (application vs. reimbursement). The
    relative date is resolved, and the summary mentions both options.

    NOTE(review): despite the test name, the expected ``planning_source`` is
    ``"rule_fallback"``. The exact same message is used by
    ``test_steward_planner_skips_llm_for_single_ambiguous_travel_flow``, which
    asserts the intent agent is never called. So the candidates checked here
    presumably come from the rule path, not from
    ``PendingFlowFunctionCallingIntentAgent``. Confirm whether an
    LLM-originated pending confirmation needs separate coverage.
    """
    payload = StewardPlanRequest(
        message="2月20-23日去上海出差辅助国网仿生产环境部署",
        client_now_iso="2026-06-15T09:30:00+08:00",
    )
    result = StewardPlannerService(intent_agent=PendingFlowFunctionCallingIntentAgent()).build_plan(payload)
    assert result.planning_source == "rule_fallback"
    assert result.next_action == "confirm_flow"
    assert result.plan_status == "needs_flow_confirmation"
    assert result.pending_flow_confirmation.status == "pending"
    assert [item.flow_id for item in result.candidate_flows] == [
        "travel_application",
        "travel_reimbursement",
    ]
    assert result.candidate_flows[0].ontology_fields["time_range"] == "2026-02-20"
    assert result.candidate_flows[0].ontology_fields["location"] == "上海"
    # One assert per option, with the summary attached, so a failure shows
    # exactly which flow the summary fails to mention.
    assert "申请" in result.summary, result.summary
    assert "报销" in result.summary, result.summary
def test_steward_planner_skips_llm_for_single_ambiguous_travel_flow() -> None:
    """A lone verb-less travel description is settled by rules without calling the agent.

    The result asks the user to choose a flow and records no model traces.
    """
    request = StewardPlanRequest(
        message="\u0032\u6708\u0032\u0030-\u0032\u0033\u65e5\u53bb\u4e0a\u6d77\u51fa\u5dee\u8f85\u52a9\u56fd\u7f51\u4eff\u751f\u4ea7\u73af\u5883\u90e8\u7f72",
        client_now_iso="2026-06-15T09:30:00+08:00",
    )
    spy_agent = CountingNoResultIntentAgent()
    plan = StewardPlannerService(intent_agent=spy_agent).build_plan(request)

    assert spy_agent.calls == 0
    assert plan.planning_source == "rule_fallback"
    assert plan.next_action == "confirm_flow"
    assert plan.plan_status == "needs_flow_confirmation"
    assert plan.model_call_traces == []
    offered_flows = [candidate.flow_id for candidate in plan.candidate_flows]
    assert offered_flows == ["travel_application", "travel_reimbursement"]
def test_steward_planner_uses_llm_for_multi_financial_demands() -> None:
    """A message with several financial demands goes to the intent agent exactly once."""
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u0037\u6708\u0032\u65e5\u53bb\u5317\u4eac\u51fa\u5dee\uff0c\u5e76\u4e14\u6211\u8981\u62a5\u9500\u6628\u5929\u7684\u4ea4\u901a\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    spy_agent = CountingFunctionCallingIntentAgent()
    plan = StewardPlannerService(intent_agent=spy_agent).build_plan(request)

    assert spy_agent.calls == 1
    assert plan.planning_source == "llm_function_call"
    first_trace = plan.model_call_traces[0]
    assert first_trace["status"] == "succeeded"
def test_steward_planner_overrides_llm_direct_application_for_ambiguous_travel_flow() -> None:
    """A verb-less travel description must end in flow confirmation with no tasks.

    The agent used here classifies the message directly as an application.

    NOTE(review): the same message is shown by
    ``test_steward_planner_skips_llm_for_single_ambiguous_travel_flow`` to
    bypass the intent agent, so this agent's payload may never be consulted.
    Confirm that the intended override path is actually exercised.
    """
    request = StewardPlanRequest(
        message="2月20-23日去上海出差辅助国网仿生产环境部署",
        client_now_iso="2026-06-15T09:30:00+08:00",
    )
    planner = StewardPlannerService(intent_agent=AmbiguousApplicationFunctionCallingIntentAgent())
    plan = planner.build_plan(request)

    assert plan.planning_source == "rule_fallback"
    assert plan.next_action == "confirm_flow"
    assert plan.plan_status == "needs_flow_confirmation"
    assert plan.tasks == []
    offered_flows = [candidate.flow_id for candidate in plan.candidate_flows]
    assert offered_flows == ["travel_application", "travel_reimbursement"]
def test_steward_planner_falls_back_to_rules_when_function_calling_is_unavailable() -> None:
    """When the agent returns nothing, the rule planner splits the request itself.

    The result is an application task plus a reimbursement task, each with a
    resolved date.
    """
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u0037\u6708\u0032\u65e5\u53bb\u5317\u4eac\u51fa\u5dee\uff0c\u5e76\u4e14\u6211\u8981\u62a5\u9500\u6628\u5929\u7684\u4ea4\u901a\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    plan = StewardPlannerService(intent_agent=EmptyFunctionCallingIntentAgent()).build_plan(request)

    assert plan.planning_source == "rule_fallback"
    task_types = [task.task_type for task in plan.tasks]
    assert task_types == ["expense_application", "reimbursement"]
    application_task, reimbursement_task = plan.tasks[0], plan.tasks[1]
    assert application_task.ontology_fields["time_range"] == "2026-07-02"
    assert reimbursement_task.ontology_fields["time_range"] == "2026-06-03"
    assert plan.thinking_events[0].stage == "rule_fallback"
def test_steward_planner_rule_fallback_confirms_ambiguous_travel_flow() -> None:
    """Without an agent result, an ambiguous travel message only yields a flow prompt.

    No tasks and no confirmation groups are produced until the user picks a
    flow.
    """
    request = StewardPlanRequest(
        message="2月20-23日去上海出差辅助国网仿生产环境部署",
        client_now_iso="2026-06-15T09:30:00+08:00",
    )
    plan = StewardPlannerService(intent_agent=EmptyFunctionCallingIntentAgent()).build_plan(request)

    assert plan.planning_source == "rule_fallback"
    assert plan.next_action == "confirm_flow"
    assert plan.pending_flow_confirmation.status == "pending"
    offered_flows = [candidate.flow_id for candidate in plan.candidate_flows]
    assert offered_flows == ["travel_application", "travel_reimbursement"]
    assert plan.tasks == []
    assert plan.confirmation_groups == []
def test_steward_planner_splits_application_and_reimbursement_tasks() -> None:
    """The default planner splits one message into three tasks.

    The tasks are one application and two reimbursements (local transport
    and a Shanghai trip), each with resolved dates and canonical fields.
    """
    request = StewardPlanRequest(
        message=(
            "我想要申请7月2日去北京出差辅助北京供电局的税务审核任务"
            "并且我要报销昨天的交通费还需要报销6月3日出差去上海的费用"
        ),
        user_id="u001",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    plan = StewardPlannerService().build_plan(request)

    assert len(plan.tasks) == 3
    assert [task.task_type for task in plan.tasks] == [
        "expense_application",
        "reimbursement",
        "reimbursement",
    ]
    application, transport, shanghai_trip = plan.tasks[0], plan.tasks[1], plan.tasks[2]

    assert application.assigned_agent == "application_assistant"
    assert application.ontology_fields["time_range"] == "2026-07-02"
    assert application.ontology_fields["location"] == "北京"
    assert application.ontology_fields["reason"] == "辅助北京供电局的税务审核任务"

    assert transport.ontology_fields["time_range"] == "2026-06-03"
    assert transport.ontology_fields["expense_type"] == "transport"

    assert shanghai_trip.ontology_fields["time_range"] == "2026-06-03"
    assert shanghai_trip.ontology_fields["location"] == "上海"
    assert shanghai_trip.ontology_fields["expense_type"] == "travel"

    # NOTE(review): all() is vacuously true if confirmation_groups is empty.
    # Consider also asserting it is non-empty if groups are expected here.
    assert not any(group.status != "pending" for group in plan.confirmation_groups)
def test_steward_planner_treats_future_travel_without_apply_word_as_application() -> None:
    """Future-dated travel with no explicit 'apply' verb becomes an application task.

    The application also gets a transport-mode gap. The trailing
    entertainment expense becomes a reimbursement task.
    """
    request = StewardPlanRequest(
        message="明天出差北京3天支撑国网仿生产部署并且报销昨天业务招待费",
        user_id="u001",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    plan = StewardPlannerService().build_plan(request)

    assert [task.task_type for task in plan.tasks] == [
        "expense_application",
        "reimbursement",
    ]
    application, reimbursement = plan.tasks[0], plan.tasks[1]

    assert application.assigned_agent == "application_assistant"
    application_fields = application.ontology_fields
    assert application_fields["time_range"] == "2026-06-05"
    assert application_fields["location"] == "北京"
    assert application_fields["expense_type"] == "travel"
    assert application_fields["reason"] == "支撑国网仿生产部署"
    assert application.missing_fields == ["transport_mode"]

    gap_events = [
        event for event in plan.thinking_events if event.stage == "business_gap_check"
    ]
    assert gap_events
    assert "没有说明出行方式" in gap_events[0].content

    assert reimbursement.assigned_agent == "reimbursement_assistant"
    reimbursement_fields = reimbursement.ontology_fields
    assert reimbursement_fields["time_range"] == "2026-06-03"
    assert reimbursement_fields["expense_type"] == "entertainment"
def test_steward_planner_outputs_only_canonical_ontology_fields() -> None:
    """Legacy review-form keys in ``context_json`` are mapped onto canonical field names."""
    legacy_form_values = {
        "occurred_date": "2026-06-03",
        "transport_type": "taxi",
        "reason_value": "客户现场沟通",
    }
    request = StewardPlanRequest(
        message="我要报销昨天的交通费",
        client_now_iso="2026-06-04T09:30:00+08:00",
        context_json={"review_form_values": legacy_form_values},
    )
    canonical = StewardPlannerService().build_plan(request).tasks[0].ontology_fields

    assert canonical["time_range"] == "2026-06-03"
    assert canonical["transport_mode"] == "taxi"
    assert canonical["reason"] == "客户现场沟通"
    for legacy_key in ("occurred_date", "transport_type", "reason_value"):
        assert legacy_key not in canonical
def test_steward_planner_builds_travel_attachment_group_with_exclusions() -> None:
    """A travel reimbursement gets one confirmable attachment group.

    The group contains the trip receipts and excludes the entertainment
    invoice. Exactly one attachment confirmation action is produced.
    """
    uploaded = ["上海高铁票.jpg", "上海酒店发票.pdf", "出租车票.png", "客户招待发票.jpg"]
    request = StewardPlanRequest(
        message="还需要报销6月3日出差去上海的费用",
        client_now_iso="2026-06-04T09:30:00+08:00",
        attachments=[StewardAttachmentInput(name=file_name) for file_name in uploaded],
    )
    plan = StewardPlannerService().build_plan(request)

    assert len(plan.attachment_groups) == 1
    travel_group = plan.attachment_groups[0]
    assert travel_group.scene == "travel"
    assert travel_group.attachment_names == ["上海高铁票.jpg", "上海酒店发票.pdf", "出租车票.png"]
    assert travel_group.excluded_attachment_names == ["客户招待发票.jpg"]
    assert travel_group.confirmation_required is True

    attachment_action_count = sum(
        1
        for action in plan.confirmation_groups
        if action.action_type == "confirm_attachment_group"
    )
    assert attachment_action_count == 1
def test_steward_stream_endpoint_emits_thinking_before_plan() -> None:
    """The streaming plan endpoint emits thinking events first and the plan last.

    The response body is read as one JSON object per non-empty line. The
    first two objects must be ``thinking`` events, starting with
    ``stream_start``, and the last must be the ``plan`` carrying the resolved
    task fields.
    """
    client = TestClient(create_app())
    with client.stream(
        "POST",
        "/api/v1/steward/plans/stream",
        json={
            "message": "我要报销昨天的交通费",
            "client_now_iso": "2026-06-04T09:30:00+08:00",
        },
    ) as response:
        assert response.status_code == 200
        # Consume the body while the stream context is still open.
        events = [
            json.loads(line.decode("utf-8") if isinstance(line, bytes) else line)
            for line in response.iter_lines()
            if line
        ]
    # Guard the indexing below so an empty or short stream fails with a clear
    # message instead of an IndexError.
    assert len(events) >= 2, f"expected at least two streamed events, got {len(events)}"
    assert [event["event"] for event in events[:2]] == ["thinking", "thinking"]
    assert events[0]["data"]["stage"] == "stream_start"
    assert events[-1]["event"] == "plan"
    assert events[-1]["data"]["tasks"][0]["ontology_fields"]["time_range"] == "2026-06-03"
def test_steward_plan_endpoint_persists_application_and_reimbursement_state() -> None:
    """The plan endpoint returns a conversation id and a per-flow ``steward_state`` snapshot.

    NOTE(review): only the HTTP response is inspected here. Whether the state
    is actually stored server-side is not checked by this test.
    """
    client = TestClient(create_app())
    request_body = {
        "message": "我想申请7月2日去北京出差并且我要报销昨天的交通费",
        "user_id": "u-steward-state",
        "client_now_iso": "2026-06-04T09:30:00+08:00",
        "context_json": {"session_type": "steward", "entry_source": "personal_workbench"},
    }
    response = client.post("/api/v1/steward/plans", json=request_body)
    assert response.status_code == 200

    body = response.json()
    assert body["conversation_id"].startswith("conv_")
    steward_state = body["steward_state"]
    assert steward_state["active_flow"] == "travel_reimbursement"

    flows = steward_state["flows"]
    application_fields = flows["travel_application"]["fields"]
    assert application_fields["location"] == "北京"
    assert application_fields["time_range"] == "2026-07-02"
    reimbursement_fields = flows["travel_reimbursement"]["fields"]
    assert reimbursement_fields["time_range"] == "2026-06-03"
    assert reimbursement_fields["expense_type"] == "transport"
    assert not any("invented_field" in flow["fields"] for flow in flows.values())