Files
X-Financial/server/tests/test_steward_planner.py
caoxiaozhu cce19e4c40 feat(steward): 拦截业务无关输入返回 off_topic 计划
- schemas/steward.py:StewardPlanResponse 新增 suggested_prompts 字段
- steward_planner.py:新增 STEWARD_BUSINESS_SIGNAL_KEYWORDS 与
  _is_business_irrelevant_input 守卫,在 build_plan 入口前置;
  新增 _build_off_topic_plan 构造 plan_status=off_topic 的引导计划
- steward_intent_agent.py:system prompt 追加业务无关约束
- test_steward_planner.py:覆盖 123/你好/纯标点走 off_topic,
  并验证正常业务输入不受守卫影响
2026-06-18 14:15:20 +08:00

616 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
from fastapi.testclient import TestClient
from app.main import create_app
from app.schemas.steward import StewardAttachmentInput, StewardPlanRequest
from app.services.steward_intent_agent import StewardIntentAgentResult
from app.services.steward_planner import StewardPlannerService
class FakeFunctionCallingIntentAgent:
    """Intent-agent test double that returns a fixed function-calling result.

    The canned payload holds one thinking event, one reimbursement task whose
    ontology fields use alias keys plus one unregistered key, one transport
    attachment group, and a single succeeded model call trace.
    """

    def detect(self, request, *, base_date, canonical_fields):
        """Check the supplied arguments and return the canned agent result.

        Asserts that ``canonical_fields`` contains ``"expense_type"`` and that
        ``base_date`` is 2026-06-04 before building the result.
        """
        assert "expense_type" in canonical_fields
        assert base_date.isoformat() == "2026-06-04"

        split_event = {
            "stage": "task_split",
            "title": "识别复合报销意图",
            "content": "模型工具调用识别出 1 个报销任务,并关联本次上传的交通附件。",
        }
        # Alias keys and an unregistered key are included on purpose so the
        # tests can check what ends up in the planner's ontology fields.
        raw_fields = {
            "occurred_date": "昨天",
            "transport_type": "出租车",
            "reason_value": "客户现场沟通",
            "expense_type": "交通费",
            "unregistered_field": "不能进入业务字段",
        }
        reimbursement_task = {
            "task_type": "reimbursement",
            "title": "费用报销 2026-06-03 交通",
            "summary": "报销昨天客户现场沟通产生的交通费。",
            "confidence": 0.91,
            "ontology_fields": raw_fields,
            "missing_fields": ["amount", "transport_type"],
        }
        transport_group = {
            "target_task_index": 1,
            "scene": "transport",
            "scene_label": "交通费用",
            "attachment_names": ["出租车票.png"],
            "excluded_attachment_names": ["客户招待发票.jpg"],
            "confidence": 0.86,
            "rationale": "出租车票与交通报销任务匹配,招待发票不归入该任务。",
        }
        succeeded_trace = {
            "slot": "main",
            "provider": "OpenAI Compatible",
            "model": "gpt-test",
            "attempt": 1,
            "status": "succeeded",
        }
        canned_payload = {
            "thinking_events": [split_event],
            "tasks": [reimbursement_task],
            "attachment_groups": [transport_group],
        }
        return StewardIntentAgentResult(
            payload=canned_payload,
            model_call_traces=[succeeded_trace],
        )
class CountingFunctionCallingIntentAgent(FakeFunctionCallingIntentAgent):
    """Canned function-calling agent that also counts ``detect`` invocations."""

    def __init__(self) -> None:
        # Number of times ``detect`` has been called on this instance.
        self.calls = 0

    def detect(self, request, *, base_date, canonical_fields):
        """Bump the call counter, then delegate to the parent implementation."""
        self.calls = self.calls + 1
        canned_result = super().detect(
            request,
            base_date=base_date,
            canonical_fields=canonical_fields,
        )
        return canned_result
class CountingNoResultIntentAgent:
    """Intent-agent stand-in that yields no result but counts its calls."""

    def __init__(self) -> None:
        # Number of times ``detect`` has been called on this instance.
        self.calls = 0

    def detect(self, request, *, base_date, canonical_fields):
        """Record one more invocation and return ``None``."""
        self.calls = self.calls + 1
        no_result = None
        return no_result
class EmptyFunctionCallingIntentAgent:
    """Intent-agent stand-in whose ``detect`` never produces a result."""

    def detect(self, request, *, base_date, canonical_fields):
        """Accept the agent arguments and return ``None`` unconditionally."""
        no_result = None
        return no_result
class EntertainmentFunctionCallingIntentAgent:
    """Intent-agent test double returning one business-entertainment task."""

    def detect(self, request, *, base_date, canonical_fields):
        """Return a canned result with a single reimbursement task.

        The task carries a relative date and Chinese expense-type label; the
        payload has no thinking events, attachment groups or call traces.
        """
        entertainment_fields = {
            "time_range": "昨天",
            "expense_type": "业务招待费",
            "reason": "业务招待",
        }
        entertainment_task = {
            "task_type": "reimbursement",
            "title": "业务招待费报销",
            "summary": "报销昨天业务招待费。",
            "confidence": 0.9,
            "ontology_fields": entertainment_fields,
            "missing_fields": [],
        }
        canned_payload = {
            "thinking_events": [],
            "tasks": [entertainment_task],
            "attachment_groups": [],
        }
        return StewardIntentAgentResult(payload=canned_payload, model_call_traces=[])
class ApplicationFunctionCallingIntentAgent:
    """Intent-agent test double returning one travel application task."""

    def detect(self, request, *, base_date, canonical_fields):
        """Return a canned result with a single ``expense_application`` task.

        The task reports no missing fields and its ontology fields contain no
        transport entry; the payload has no attachment groups or call traces.
        """
        split_event = {
            "stage": "task_split",
            "title": "识别出差申请",
            "content": "模型识别到用户要发起北京出差申请,并且后续还有报销事项。",
        }
        application_fields = {
            "time_range": "明天",
            "location": "北京",
            "expense_type": "差旅",
            "reason": "支撑国网仿生产部署",
        }
        application_task = {
            "task_type": "expense_application",
            "title": "北京出差申请",
            "summary": "明天前往北京出差3天支撑国网仿生产部署。",
            "confidence": 0.94,
            "ontology_fields": application_fields,
            "missing_fields": [],
        }
        canned_payload = {
            "thinking_events": [split_event],
            "tasks": [application_task],
            "attachment_groups": [],
        }
        return StewardIntentAgentResult(payload=canned_payload, model_call_traces=[])
class PendingFlowFunctionCallingIntentAgent:
    """Intent-agent test double that asks for a flow-direction confirmation."""

    def detect(self, request, *, base_date, canonical_fields):
        """Return a canned result with a pending flow confirmation and no tasks.

        The confirmation echoes ``request.message`` as its source message and
        offers two candidate flows: a travel application and a travel
        reimbursement.
        """
        confirmation_event = {
            "stage": "flow_confirmation",
            "title": "识别到出差事项但动作不明确",
            "content": "用户提供了时间、地点和事由,但没有明确要补办申请还是发起报销。",
        }
        # Both candidates describe the same trip; each gets its own dict copy
        # so the two candidates never share a mutable object.
        trip_fields = {
            "time_range": "2月20日",
            "location": "上海",
            "expense_type": "差旅",
            "reason": "辅助国网仿生产环境部署",
        }
        application_candidate = {
            "flow_id": "travel_application",
            "label": "补办出差申请",
            "confidence": 0.52,
            "reason": "这句话可以理解为补办出差申请。",
            "ontology_fields": dict(trip_fields),
            "missing_fields": ["transport_mode"],
        }
        reimbursement_candidate = {
            "flow_id": "travel_reimbursement",
            "label": "发起费用报销",
            "confidence": 0.48,
            "reason": "这句话也可能是在为已发生出差发起报销。",
            "ontology_fields": dict(trip_fields),
            "missing_fields": [],
        }
        pending_confirmation = {
            "status": "pending",
            "source_message": request.message,
            "reason": "缺少申请或报销动作词,需要用户确认流程方向。",
            "candidate_flows": [application_candidate, reimbursement_candidate],
        }
        canned_payload = {
            "thinking_events": [confirmation_event],
            "pending_flow_confirmation": pending_confirmation,
            "tasks": [],
            "attachment_groups": [],
        }
        return StewardIntentAgentResult(payload=canned_payload, model_call_traces=[])
class AmbiguousApplicationFunctionCallingIntentAgent:
    """Intent-agent test double that classifies a trip directly as an application."""

    def detect(self, request, *, base_date, canonical_fields):
        """Return a canned result with one ``expense_application`` task.

        The task lists ``transport_mode`` as missing and the result carries a
        single succeeded model call trace.
        """
        split_event = {
            "stage": "task_split",
            "title": "模型直接判定为申请",
            "content": "模型误把无动作词的历史出差描述直接判定为申请。",
        }
        trip_fields = {
            "time_range": "2月20日",
            "location": "上海",
            "expense_type": "差旅",
            "reason": "辅助国网仿生产环境部署",
        }
        application_task = {
            "task_type": "expense_application",
            "title": "上海出差申请",
            "summary": "2月20-23日去上海出差辅助国网仿生产环境部署。",
            "confidence": 0.9,
            "ontology_fields": trip_fields,
            "missing_fields": ["transport_mode"],
        }
        canned_payload = {
            "thinking_events": [split_event],
            "tasks": [application_task],
            "attachment_groups": [],
        }
        return StewardIntentAgentResult(
            payload=canned_payload,
            model_call_traces=[{"status": "succeeded"}],
        )
def test_steward_planner_uses_llm_function_calling_plan_when_available() -> None:
    """A usable function-calling result drives the plan and is canonicalized.

    Checks that alias and unregistered ontology keys from the agent are absent
    from the task fields, that values are normalized, and that the attachment
    grouping from the agent is kept.
    """
    uploaded = [
        StewardAttachmentInput(name=file_name)
        for file_name in ("出租车票.png", "客户招待发票.jpg")
    ]
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u0037\u6708\u0032\u65e5\u53bb\u5317\u4eac\u51fa\u5dee\uff0c\u5e76\u4e14\u6211\u8981\u62a5\u9500\u6628\u5929\u5ba2\u6237\u73b0\u573a\u6c9f\u901a\u7684\u4ea4\u901a\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
        attachments=uploaded,
    )
    planner = StewardPlannerService(intent_agent=FakeFunctionCallingIntentAgent())

    plan = planner.build_plan(request)

    assert plan.planning_source == "llm_function_call"
    assert plan.model_call_traces[0]["status"] == "succeeded"
    assert len(plan.tasks) == 1

    task = plan.tasks[0]
    fields = task.ontology_fields
    assert fields["time_range"] == "2026-06-03"
    assert fields["transport_mode"] == "taxi"
    assert fields["reason"] == "客户现场沟通"
    assert fields["expense_type"] == "transport"
    # None of the keys the agent supplied under alias/unknown names may remain.
    for dropped_key in (
        "occurred_date",
        "transport_type",
        "reason_value",
        "unregistered_field",
    ):
        assert dropped_key not in fields
    assert task.missing_fields == ["amount"]

    group = plan.attachment_groups[0]
    assert group.attachment_names == ["出租车票.png"]
    assert group.excluded_attachment_names == ["客户招待发票.jpg"]

    first_event = plan.thinking_events[0]
    assert first_event.stage == "llm_function_call"
def test_steward_planner_normalizes_llm_business_entertainment_expense_type() -> None:
    """The agent's entertainment label and relative date come back normalized."""
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u0037\u6708\u0032\u65e5\u53bb\u5317\u4eac\u51fa\u5dee\uff0c\u5e76\u4e14\u62a5\u9500\u6628\u5929\u4e1a\u52a1\u62db\u5f85\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    planner = StewardPlannerService(intent_agent=EntertainmentFunctionCallingIntentAgent())

    plan = planner.build_plan(request)

    assert plan.planning_source == "llm_function_call"
    first_task_fields = plan.tasks[0].ontology_fields
    assert first_task_fields["expense_type"] == "entertainment"
    assert plan.tasks[0].ontology_fields["time_range"] == "2026-06-03"
def test_steward_planner_enforces_application_transport_gap_after_function_calling() -> None:
    """A missing transport mode is flagged even when the agent reports no gaps."""
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u660e\u5929\u51fa\u5dee\u5317\u4eac\u0033\u5929\uff0c\u652f\u6491\u56fd\u7f51\u4eff\u751f\u4ea7\u90e8\u7f72\uff0c\u5e76\u4e14\u6211\u8981\u62a5\u9500\u6628\u5929\u7684\u4ea4\u901a\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    planner = StewardPlannerService(intent_agent=ApplicationFunctionCallingIntentAgent())

    plan = planner.build_plan(request)

    assert plan.planning_source == "llm_function_call"
    assert plan.tasks[0].missing_fields == ["transport_mode"]

    gap_events = [
        thinking_event
        for thinking_event in plan.thinking_events
        if thinking_event.stage == "business_gap_check"
    ]
    assert gap_events
    first_gap_text = gap_events[0].content
    assert "没有说明出行方式" in first_gap_text
    assert "火车、飞机或轮船" in gap_events[0].content
def test_steward_planner_returns_pending_flow_confirmation_from_llm() -> None:
    """An action-less trip description ends in a pending flow confirmation.

    NOTE(review): the name says "from_llm", yet ``planning_source`` is asserted
    to be ``"rule_fallback"`` -- confirm this is the intended source.
    """
    request = StewardPlanRequest(
        message="2月20-23日去上海出差辅助国网仿生产环境部署",
        client_now_iso="2026-06-15T09:30:00+08:00",
    )
    planner = StewardPlannerService(intent_agent=PendingFlowFunctionCallingIntentAgent())

    plan = planner.build_plan(request)

    assert plan.planning_source == "rule_fallback"
    assert plan.next_action == "confirm_flow"
    assert plan.plan_status == "needs_flow_confirmation"
    assert plan.pending_flow_confirmation.status == "pending"

    offered_flow_ids = [candidate.flow_id for candidate in plan.candidate_flows]
    assert offered_flow_ids == ["travel_application", "travel_reimbursement"]

    first_candidate = plan.candidate_flows[0]
    assert first_candidate.ontology_fields["time_range"] == "2026-02-20"
    assert plan.candidate_flows[0].ontology_fields["location"] == "上海"
    # The summary has to mention both possible directions.
    assert "申请" in plan.summary and "报销" in plan.summary
def test_steward_planner_skips_llm_for_single_ambiguous_travel_flow() -> None:
    """A single ambiguous trip message is planned without calling the agent."""
    request = StewardPlanRequest(
        message="\u0032\u6708\u0032\u0030-\u0032\u0033\u65e5\u53bb\u4e0a\u6d77\u51fa\u5dee\u8f85\u52a9\u56fd\u7f51\u4eff\u751f\u4ea7\u73af\u5883\u90e8\u7f72",
        client_now_iso="2026-06-15T09:30:00+08:00",
    )
    counting_agent = CountingNoResultIntentAgent()
    planner = StewardPlannerService(intent_agent=counting_agent)

    plan = planner.build_plan(request)

    # The agent's detect() must not have run at all.
    assert counting_agent.calls == 0
    assert plan.planning_source == "rule_fallback"
    assert plan.next_action == "confirm_flow"
    assert plan.plan_status == "needs_flow_confirmation"
    assert plan.model_call_traces == []

    offered_flow_ids = [candidate.flow_id for candidate in plan.candidate_flows]
    assert offered_flow_ids == ["travel_application", "travel_reimbursement"]
def test_steward_planner_uses_llm_for_multi_financial_demands() -> None:
    """A message with several financial demands calls the agent exactly once."""
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u0037\u6708\u0032\u65e5\u53bb\u5317\u4eac\u51fa\u5dee\uff0c\u5e76\u4e14\u6211\u8981\u62a5\u9500\u6628\u5929\u7684\u4ea4\u901a\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    counting_agent = CountingFunctionCallingIntentAgent()
    planner = StewardPlannerService(intent_agent=counting_agent)

    plan = planner.build_plan(request)

    assert counting_agent.calls == 1
    assert plan.planning_source == "llm_function_call"
    first_trace = plan.model_call_traces[0]
    assert first_trace["status"] == "succeeded"
def test_steward_planner_overrides_llm_direct_application_for_ambiguous_travel_flow() -> None:
    """An agent that would answer "application" does not decide an ambiguous trip.

    The plan must carry no tasks and instead offer both candidate flows.
    """
    request = StewardPlanRequest(
        message="2月20-23日去上海出差辅助国网仿生产环境部署",
        client_now_iso="2026-06-15T09:30:00+08:00",
    )
    planner = StewardPlannerService(
        intent_agent=AmbiguousApplicationFunctionCallingIntentAgent(),
    )

    plan = planner.build_plan(request)

    assert plan.planning_source == "rule_fallback"
    assert plan.next_action == "confirm_flow"
    assert plan.plan_status == "needs_flow_confirmation"
    assert plan.tasks == []

    offered_flow_ids = [candidate.flow_id for candidate in plan.candidate_flows]
    assert offered_flow_ids == ["travel_application", "travel_reimbursement"]
def test_steward_planner_falls_back_to_rules_when_function_calling_is_unavailable() -> None:
    """When the agent returns nothing, rules split the message into two tasks."""
    request = StewardPlanRequest(
        message="\u6211\u60f3\u7533\u8bf7\u0037\u6708\u0032\u65e5\u53bb\u5317\u4eac\u51fa\u5dee\uff0c\u5e76\u4e14\u6211\u8981\u62a5\u9500\u6628\u5929\u7684\u4ea4\u901a\u8d39",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    planner = StewardPlannerService(intent_agent=EmptyFunctionCallingIntentAgent())

    plan = planner.build_plan(request)

    assert plan.planning_source == "rule_fallback"
    task_types = [task.task_type for task in plan.tasks]
    assert task_types == ["expense_application", "reimbursement"]

    application_fields = plan.tasks[0].ontology_fields
    assert application_fields["time_range"] == "2026-07-02"
    reimbursement_fields = plan.tasks[1].ontology_fields
    assert reimbursement_fields["time_range"] == "2026-06-03"

    first_event = plan.thinking_events[0]
    assert first_event.stage == "rule_fallback"
def test_steward_planner_rule_fallback_confirms_ambiguous_travel_flow() -> None:
    """Rule fallback asks for the flow direction instead of creating tasks."""
    request = StewardPlanRequest(
        message="2月20-23日去上海出差辅助国网仿生产环境部署",
        client_now_iso="2026-06-15T09:30:00+08:00",
    )
    planner = StewardPlannerService(intent_agent=EmptyFunctionCallingIntentAgent())

    plan = planner.build_plan(request)

    assert plan.planning_source == "rule_fallback"
    assert plan.next_action == "confirm_flow"
    assert plan.pending_flow_confirmation.status == "pending"

    offered_flow_ids = [candidate.flow_id for candidate in plan.candidate_flows]
    assert offered_flow_ids == ["travel_application", "travel_reimbursement"]

    # Nothing is actionable until the user picks a flow.
    assert plan.tasks == []
    assert plan.confirmation_groups == []
def test_steward_planner_splits_application_and_reimbursement_tasks() -> None:
    """One compound message yields an application and two reimbursement tasks."""
    compound_message = (
        "我想要申请7月2日去北京出差辅助北京供电局的税务审核任务"
        "并且我要报销昨天的交通费还需要报销6月3日出差去上海的费用"
    )
    request = StewardPlanRequest(
        message=compound_message,
        user_id="u001",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )

    plan = StewardPlannerService().build_plan(request)

    assert len(plan.tasks) == 3
    task_types = [task.task_type for task in plan.tasks]
    assert task_types == ["expense_application", "reimbursement", "reimbursement"]

    application = plan.tasks[0]
    assert application.assigned_agent == "application_assistant"
    assert application.ontology_fields["time_range"] == "2026-07-02"
    assert application.ontology_fields["location"] == "北京"
    assert application.ontology_fields["reason"] == "辅助北京供电局的税务审核任务"

    transport_fields = plan.tasks[1].ontology_fields
    assert transport_fields["time_range"] == "2026-06-03"
    assert transport_fields["expense_type"] == "transport"

    travel_fields = plan.tasks[2].ontology_fields
    assert travel_fields["time_range"] == "2026-06-03"
    assert travel_fields["location"] == "上海"
    assert travel_fields["expense_type"] == "travel"

    # Every confirmation action must still be awaiting the user.
    assert not any(action.status != "pending" for action in plan.confirmation_groups)
def test_steward_planner_treats_future_travel_without_apply_word_as_application() -> None:
    """A future trip without an explicit "apply" verb still becomes an application."""
    request = StewardPlanRequest(
        message="明天出差北京3天支撑国网仿生产部署并且报销昨天业务招待费",
        user_id="u001",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )

    plan = StewardPlannerService().build_plan(request)

    task_types = [task.task_type for task in plan.tasks]
    assert task_types == ["expense_application", "reimbursement"]

    application = plan.tasks[0]
    assert application.assigned_agent == "application_assistant"
    assert application.ontology_fields["time_range"] == "2026-06-05"
    assert application.ontology_fields["location"] == "北京"
    assert application.ontology_fields["expense_type"] == "travel"
    assert application.ontology_fields["reason"] == "支撑国网仿生产部署"
    assert application.missing_fields == ["transport_mode"]

    gap_events = [
        thinking_event
        for thinking_event in plan.thinking_events
        if thinking_event.stage == "business_gap_check"
    ]
    assert gap_events
    assert "没有说明出行方式" in gap_events[0].content

    reimbursement = plan.tasks[1]
    assert reimbursement.assigned_agent == "reimbursement_assistant"
    assert reimbursement.ontology_fields["time_range"] == "2026-06-03"
    assert reimbursement.ontology_fields["expense_type"] == "entertainment"
def test_steward_planner_outputs_only_canonical_ontology_fields() -> None:
    """Alias keys in ``review_form_values`` surface only under canonical names."""
    review_form_values = {
        "occurred_date": "2026-06-03",
        "transport_type": "taxi",
        "reason_value": "客户现场沟通",
    }
    request = StewardPlanRequest(
        message="我要报销昨天的交通费",
        client_now_iso="2026-06-04T09:30:00+08:00",
        context_json={"review_form_values": review_form_values},
    )

    plan = StewardPlannerService().build_plan(request)

    fields = plan.tasks[0].ontology_fields
    assert fields["time_range"] == "2026-06-03"
    assert fields["transport_mode"] == "taxi"
    assert fields["reason"] == "客户现场沟通"
    # The alias names themselves must not appear in the output.
    for alias_key in ("occurred_date", "transport_type", "reason_value"):
        assert alias_key not in fields
def test_steward_planner_builds_travel_attachment_group_with_exclusions() -> None:
    """Travel attachments are grouped together and the entertainment invoice excluded."""
    uploaded_names = (
        "上海高铁票.jpg",
        "上海酒店发票.pdf",
        "出租车票.png",
        "客户招待发票.jpg",
    )
    request = StewardPlanRequest(
        message="还需要报销6月3日出差去上海的费用",
        client_now_iso="2026-06-04T09:30:00+08:00",
        attachments=[StewardAttachmentInput(name=file_name) for file_name in uploaded_names],
    )

    plan = StewardPlannerService().build_plan(request)

    assert len(plan.attachment_groups) == 1
    travel_group = plan.attachment_groups[0]
    assert travel_group.scene == "travel"
    assert travel_group.attachment_names == ["上海高铁票.jpg", "上海酒店发票.pdf", "出租车票.png"]
    assert travel_group.excluded_attachment_names == ["客户招待发票.jpg"]
    assert travel_group.confirmation_required is True

    # Exactly one confirmation action should concern the attachment grouping.
    attachment_actions = [
        pending_action
        for pending_action in plan.confirmation_groups
        if pending_action.action_type == "confirm_attachment_group"
    ]
    assert len(attachment_actions) == 1
def test_steward_stream_endpoint_emits_thinking_before_plan() -> None:
    """The streaming endpoint emits thinking events first and the plan last.

    Each non-empty response line is parsed as one JSON event.
    """
    client = TestClient(create_app())
    request_body = {
        "message": "我要报销昨天的交通费",
        "client_now_iso": "2026-06-04T09:30:00+08:00",
    }

    with client.stream("POST", "/api/v1/steward/plans/stream", json=request_body) as response:
        assert response.status_code == 200
        events = []
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            # Lines may arrive as bytes or str depending on the client version.
            text_line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
            events.append(json.loads(text_line))

    event_kinds = [event["event"] for event in events]
    assert event_kinds[:2] == ["thinking", "thinking"]
    assert events[0]["data"]["stage"] == "stream_start"

    final_event = events[-1]
    assert final_event["event"] == "plan"
    first_task = final_event["data"]["tasks"][0]
    assert first_task["ontology_fields"]["time_range"] == "2026-06-03"
def test_steward_plan_endpoint_persists_application_and_reimbursement_state() -> None:
    """The plan endpoint returns steward state holding both flows' fields."""
    client = TestClient(create_app())
    request_body = {
        "message": "我想申请7月2日去北京出差并且我要报销昨天的交通费",
        "user_id": "u-steward-state",
        "client_now_iso": "2026-06-04T09:30:00+08:00",
        "context_json": {"session_type": "steward", "entry_source": "personal_workbench"},
    }

    response = client.post("/api/v1/steward/plans", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["conversation_id"].startswith("conv_")

    state = body["steward_state"]
    assert state["active_flow"] == "travel_reimbursement"

    application_fields = state["flows"]["travel_application"]["fields"]
    assert application_fields["location"] == "北京"
    assert application_fields["time_range"] == "2026-07-02"

    reimbursement_fields = state["flows"]["travel_reimbursement"]["fields"]
    assert reimbursement_fields["time_range"] == "2026-06-03"
    assert reimbursement_fields["expense_type"] == "transport"

    # No flow may contain a field named "invented_field".
    assert not any("invented_field" in flow["fields"] for flow in state["flows"].values())
def test_steward_planner_returns_off_topic_for_business_irrelevant_input() -> None:
    """A bare number is answered with an empty ``off_topic`` plan and prompts."""
    request = StewardPlanRequest(
        message="123",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    planner = StewardPlannerService()

    plan = planner.build_plan(request)

    assert plan.plan_status == "off_topic"
    assert plan.next_action == "none"
    # An off-topic plan carries no work items of any kind.
    assert plan.tasks == []
    assert plan.attachment_groups == []
    assert plan.confirmation_groups == []
    assert plan.candidate_flows == []
    assert plan.planning_source == "rule_fallback"

    prompt_count = len(plan.suggested_prompts)
    assert prompt_count == 3
    first_event = plan.thinking_events[0]
    assert first_event.stage == "off_topic"
def test_steward_planner_returns_off_topic_for_pure_greeting() -> None:
    """A plain greeting is answered with an empty ``off_topic`` plan and prompts."""
    request = StewardPlanRequest(
        message="你好",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    planner = StewardPlannerService()

    plan = planner.build_plan(request)

    assert plan.plan_status == "off_topic"
    assert plan.next_action == "none"
    assert plan.tasks == []
    assert plan.candidate_flows == []
    assert plan.planning_source == "rule_fallback"

    prompt_count = len(plan.suggested_prompts)
    assert prompt_count == 3
    first_event = plan.thinking_events[0]
    assert first_event.stage == "off_topic"
def test_steward_planner_returns_off_topic_for_pure_punctuation() -> None:
    """Punctuation-only input is answered with an empty ``off_topic`` plan."""
    request = StewardPlanRequest(
        message="??? !!!",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    planner = StewardPlannerService()

    plan = planner.build_plan(request)

    assert plan.plan_status == "off_topic"
    assert plan.next_action == "none"
    assert plan.tasks == []
    assert plan.candidate_flows == []
    assert plan.planning_source == "rule_fallback"

    prompt_count = len(plan.suggested_prompts)
    assert prompt_count == 3
    first_event = plan.thinking_events[0]
    assert first_event.stage == "off_topic"
def test_steward_planner_preserves_normal_business_flow_after_guard() -> None:
    """An ordinary reimbursement request is not classified as off-topic."""
    request = StewardPlanRequest(
        message="我要报销昨天的交通费",
        client_now_iso="2026-06-04T09:30:00+08:00",
    )
    planner = StewardPlannerService()

    plan = planner.build_plan(request)

    assert plan.plan_status != "off_topic"
    task_count = len(plan.tasks)
    assert task_count >= 1
    task_types = [task.task_type for task in plan.tasks]
    assert task_types == ["reimbursement"]