feat(server): 更新用户代理服务架构,增强用户行为追踪和会话管理功能,包含schema、service和单元测试

This commit is contained in:
caoxiaozhu
2026-05-14 15:42:33 +00:00
parent fad583ee7c
commit ad16358e71
3 changed files with 664 additions and 41 deletions

View File

@@ -117,6 +117,8 @@ class UserAgentReviewDocumentCard(BaseModel):
scene_label: str = Field(default="", description="面向用户展示的场景标签。")
summary: str = Field(default="", description="逐票据摘要。")
avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="OCR 平均得分。")
preview_kind: str = Field(default="", description="票据预览类型,例如 image。")
preview_data_url: str = Field(default="", description="票据预览图片 data URL。")
warnings: list[str] = Field(default_factory=list, description="该票据的识别提示。")
fields: list[UserAgentReviewDocumentField] = Field(
default_factory=list,

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import re
from datetime import UTC, datetime, timedelta
from decimal import Decimal, InvalidOperation
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
@@ -118,6 +119,11 @@ SLOT_LABELS = {
DATE_TEXT_PATTERN = re.compile(r"(\d{4}[年/-]\d{1,2}[月/-]\d{1,2}日?)")
AMOUNT_TEXT_PATTERN = re.compile(r"(\d+(?:\.\d+)?)\s*(?:元|万元|万)")
DOCUMENT_AMOUNT_PATTERN = re.compile(
r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
r"[:\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
)
DOCUMENT_CURRENCY_AMOUNT_PATTERN = re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)")
SOURCE_LABELS = {
"user_text": "用户描述",
@@ -130,6 +136,36 @@ SOURCE_LABELS = {
"system": "系统判断",
}
SCENE_REQUIRED_SLOT_KEYS = {
"hotel": {"merchant_name"},
"meeting": {"location"},
"entertainment": {"location", "customer_name", "participants"},
}
INFERRED_REASON_LABELS = {
"travel": "出差行程",
"hotel": "住宿报销",
"transport": "交通出行",
"meal": "餐饮用餐",
"meeting": "会务活动",
"entertainment": "客户接待",
"office": "办公采购",
"training": "培训学习",
"communication": "通讯使用",
"welfare": "员工福利",
"other": "其他费用",
}
SYSTEM_GENERATED_REASON_PREFIXES = (
"我上传了",
"请按当前已识别信息",
"请把当前上传的票据",
"请基于当前上传的多张票据",
"我已核对右侧识别结果",
"请同步修正逐票据识别结果",
"我已修改识别信息",
"查看报销草稿",
"请解释一下当前这笔报销的合规风险和待补充项",
)
class UserAgentService:
def __init__(self, db: Session) -> None:
@@ -736,10 +772,15 @@ class UserAgentService:
document_cards=document_cards,
claim_groups=claim_groups,
)
can_proceed = self._can_proceed_review(
payload,
missing_slot_keys=missing_slot_keys,
claim_groups=claim_groups,
association_choice_pending = self._is_review_association_choice_pending(payload)
can_proceed = (
False
if association_choice_pending
else self._can_proceed_review(
payload,
missing_slot_keys=missing_slot_keys,
claim_groups=claim_groups,
)
)
confirmation_actions = self._build_review_confirmation_actions(
payload,
@@ -762,6 +803,7 @@ class UserAgentService:
slot_cards=slot_cards,
risk_briefs=risk_briefs,
can_proceed=can_proceed,
document_cards=document_cards,
)
return UserAgentReviewPayload(
@@ -798,7 +840,10 @@ class UserAgentService:
ocr_documents=ocr_documents,
)
merchant_slot = self._build_merchant_slot(payload, ocr_documents=ocr_documents)
reason_slot = self._build_reason_slot(payload)
reason_slot = self._build_reason_slot(
payload,
claim_groups=claim_groups,
)
attachment_slot = self._build_attachment_slot(payload)
required_keys = self._resolve_required_review_keys(
payload,
@@ -922,6 +967,8 @@ class UserAgentService:
),
summary=str(item.get("summary") or item.get("text") or "").strip(),
avg_score=float(item.get("avg_score") or 0.0),
preview_kind=str(item.get("preview_kind") or "").strip(),
preview_data_url=str(item.get("preview_data_url") or "").strip(),
warnings=[str(warning) for warning in item.get("warnings", []) if str(warning).strip()],
fields=[
UserAgentReviewDocumentField(
@@ -950,14 +997,22 @@ class UserAgentService:
{
"document_indexes": [],
"amount_total": 0.0,
"expense_type": group_code,
"scene_label": GROUP_SCENE_LABELS.get(group_code, "其他费用"),
"expense_type": str(card.suggested_expense_type or group_code).strip() or group_code,
"scene_label": GROUP_SCENE_LABELS.get(
str(card.suggested_expense_type or group_code).strip() or group_code,
GROUP_SCENE_LABELS.get(group_code, "其他费用"),
),
"reasons": [],
},
)
bucket["document_indexes"].append(card.index)
bucket["amount_total"] = float(bucket["amount_total"]) + self._extract_amount_from_card(card)
bucket["reasons"].append(f"{card.filename} 识别为 {card.scene_label}")
current_expense_type = str(bucket["expense_type"] or "").strip()
current_card_type = str(card.suggested_expense_type or "").strip()
if current_expense_type and current_card_type and current_expense_type != current_card_type:
bucket["expense_type"] = group_code
bucket["scene_label"] = GROUP_SCENE_LABELS.get(group_code, "其他费用")
if not groups:
expense_type_code = self._collect_entity_values(payload).get("expense_type_code", "other")
@@ -1080,6 +1135,40 @@ class UserAgentService:
claim_groups: list[UserAgentReviewClaimGroup],
draft_payload: UserAgentDraftPayload | None,
) -> list[UserAgentReviewAction]:
if self._is_review_association_choice_pending(payload):
claim_no = str(payload.tool_payload.get("association_candidate_claim_no") or "").strip()
link_label = f"关联到草稿 {claim_no}" if claim_no else "关联到现有草稿"
return [
UserAgentReviewAction(
label="取消",
action_type="cancel_review",
description="放弃当前识别结果,并退出本次核对流程。",
emphasis="secondary",
),
UserAgentReviewAction(
label="修改识别信息",
action_type="edit_review",
description="打开结构化模板,按已识别字段逐项修改。",
emphasis="secondary",
),
UserAgentReviewAction(
label=link_label,
action_type="link_to_existing_draft",
description=(
f"把本次上传票据并入现有草稿 {claim_no}"
if claim_no
else "把本次上传票据并入现有草稿。"
),
emphasis="primary",
),
UserAgentReviewAction(
label="单独建立报销单",
action_type="create_new_claim_from_documents",
description="基于当前上传的多张票据,新建一张独立的报销草稿。",
emphasis="secondary",
),
]
primary_action = UserAgentReviewAction(
label="继续下一步" if can_proceed else "保存为草稿",
action_type="next_step" if can_proceed else "save_draft",
@@ -1171,6 +1260,22 @@ class UserAgentService:
"后续您可以继续补充缺失项,或修改识别结果后再继续提交。"
)
return "已按您当前确认的信息保存为草稿。后续您可以继续补充缺失项,或修改识别结果后再继续提交。"
if review_action == "link_to_existing_draft":
document_count = self._resolve_review_document_count(payload)
if draft_payload is not None and draft_payload.claim_no:
return (
f"已将本次上传的 {document_count} 张票据关联到草稿 {draft_payload.claim_no}"
"您可以继续补充识别字段,确认无误后再提交审批。"
)
return "已将本次上传的票据关联到现有草稿。您可以继续补充识别字段,确认无误后再提交审批。"
if review_action == "create_new_claim_from_documents":
document_count = self._resolve_review_document_count(payload)
if draft_payload is not None and draft_payload.claim_no:
return (
f"已按当前上传的 {document_count} 张票据新建报销草稿 {draft_payload.claim_no}"
"您可以继续补充识别字段,确认无误后再提交审批。"
)
return "已按当前上传票据新建报销草稿。您可以继续补充识别字段,确认无误后再提交审批。"
if review_action == "next_step":
if draft_payload is not None and draft_payload.status == "submitted":
stage_text = draft_payload.approval_stage or "审批中"
@@ -1195,7 +1300,21 @@ class UserAgentService:
slot_cards: list[UserAgentReviewSlotCard],
risk_briefs: list[UserAgentReviewRiskBrief],
can_proceed: bool,
document_cards: list[UserAgentReviewDocumentCard],
) -> str:
if self._is_review_association_choice_pending(payload):
claim_no = str(payload.tool_payload.get("association_candidate_claim_no") or "").strip()
document_count = len(document_cards) or self._resolve_review_document_count(payload)
if claim_no:
return (
f"已识别出本次上传的 {document_count} 张票据。"
f"系统检测到你已有草稿 {claim_no},请选择关联到该草稿,或单独建立一张新的报销单。"
)
return (
f"已识别出本次上传的 {document_count} 张票据。"
"系统检测到你已有可用草稿,请先选择关联到现有草稿,或单独建立一张新的报销单。"
)
review_payload = UserAgentReviewPayload(
intent_summary="",
body_message="",
@@ -1423,6 +1542,22 @@ class UserAgentService:
return cleaned[:300]
return ""
@staticmethod
def _looks_like_system_generated_reason_message(message: str) -> bool:
cleaned = str(message or "").strip()
if not cleaned:
return False
compact = re.sub(r"\s+", "", cleaned)
return compact.startswith(SYSTEM_GENERATED_REASON_PREFIXES)
def _resolve_reason_source_text(self, payload: UserAgentRequest) -> str:
explicit_text = payload.context_json.get("user_input_text")
if isinstance(explicit_text, str):
return explicit_text.strip()
if self._looks_like_system_generated_reason_message(payload.message):
return ""
return str(payload.message or "").strip()
@classmethod
def _resolve_reason_text(cls, message: str) -> str:
reason = cls._extract_message_reason(message)
@@ -1553,13 +1688,58 @@ class UserAgentService:
documents = payload.context_json.get("ocr_documents")
if not isinstance(documents, list):
return []
overrides = payload.context_json.get("review_document_form_values")
override_map: dict[tuple[int, str], dict[str, object]] = {}
if isinstance(overrides, list):
for item in overrides:
if not isinstance(item, dict):
continue
filename = str(item.get("filename") or "").strip()
index = int(item.get("index") or 0)
if not filename and index <= 0:
continue
override_map[(index, filename)] = item
normalized: list[dict[str, object]] = []
for item in documents[:8]:
for index, item in enumerate(documents[:8], start=1):
if not isinstance(item, dict):
continue
normalized.append(item)
normalized_item = dict(item)
override = override_map.get((index, str(normalized_item.get("filename") or "").strip()))
if override is None:
override = override_map.get((index, ""))
if override is not None:
summary = str(override.get("summary") or "").strip()
scene_label = str(override.get("scene_label") or "").strip()
fields = override.get("fields")
if summary:
normalized_item["summary"] = summary
if scene_label:
normalized_item["scene_label"] = scene_label
if isinstance(fields, list):
normalized_item["document_fields"] = [
{
"key": str(field.get("key") or field.get("label") or "").strip(),
"label": str(field.get("label") or "").strip(),
"value": str(field.get("value") or "").strip(),
}
for field in fields
if isinstance(field, dict)
and str(field.get("label") or "").strip()
and str(field.get("value") or "").strip()
]
normalized.append(normalized_item)
return normalized
@staticmethod
def _is_review_association_choice_pending(payload: UserAgentRequest) -> bool:
return bool(payload.tool_payload.get("pending_association_decision"))
def _resolve_review_document_count(self, payload: UserAgentRequest) -> int:
return max(
len(self._resolve_ocr_documents(payload)),
self._resolve_attachment_count(payload),
)
@staticmethod
def _resolve_conversation_history(payload: UserAgentRequest) -> list[dict[str, object]]:
history = payload.context_json.get("conversation_history")
@@ -1852,7 +2032,12 @@ class UserAgentService:
)
return self._build_slot_value()
def _build_reason_slot(self, payload: UserAgentRequest) -> dict[str, str | float]:
def _build_reason_slot(
self,
payload: UserAgentRequest,
*,
claim_groups: list[UserAgentReviewClaimGroup],
) -> dict[str, str | float]:
review_form_values = self._resolve_review_form_values(payload)
edited_value = str(review_form_values.get("reason") or "").strip()
if edited_value:
@@ -1865,7 +2050,7 @@ class UserAgentService:
evidence="来源于用户修改后的结构化表单。",
)
reason_value = self._resolve_reason_text(payload.message)
reason_value = self._resolve_reason_text(self._resolve_reason_source_text(payload))
if reason_value:
return self._build_slot_value(
value=reason_value,
@@ -1875,6 +2060,19 @@ class UserAgentService:
confidence=0.76,
evidence="系统从用户原始描述中提取了本次费用事由,建议继续核对。",
)
inferred_reason = self._infer_reason_from_claim_groups(
claim_groups=claim_groups,
)
if inferred_reason:
return self._build_slot_value(
value=inferred_reason,
raw_value=inferred_reason,
normalized_value=inferred_reason,
source="ocr",
confidence=0.68,
evidence="系统已根据票据识别场景补全通用事由,若需更具体说明可继续修改。",
)
return self._build_slot_value()
def _build_amount_slot(
@@ -2072,7 +2270,10 @@ class UserAgentService:
if primary_expense_type:
scene_codes.add(primary_expense_type)
compact_message = re.sub(r"\s+", "", payload.message)
for scene_code in scene_codes:
required.update(SCENE_REQUIRED_SLOT_KEYS.get(scene_code, set()))
compact_message = re.sub(r"\s+", "", self._resolve_reason_source_text(payload) or payload.message)
if "entertainment" in scene_codes or (
"客户" in compact_message and any(keyword in compact_message for keyword in ("招待", "吃饭", "用餐", "宴请", "请客"))
):
@@ -2080,6 +2281,24 @@ class UserAgentService:
return required
@staticmethod
def _infer_reason_from_claim_groups(
*,
claim_groups: list[UserAgentReviewClaimGroup],
) -> str:
if len(claim_groups) == 1:
document_indexes = list(claim_groups[0].document_indexes or [])
if not document_indexes:
return ""
expense_type = str(claim_groups[0].expense_type or "").strip()
group_code = str(claim_groups[0].group_code or "").strip()
if expense_type:
return INFERRED_REASON_LABELS.get(expense_type, "") or str(claim_groups[0].scene_label or "").strip()
if group_code:
return INFERRED_REASON_LABELS.get(group_code, "") or str(claim_groups[0].scene_label or "").strip()
return ""
@staticmethod
def _resolve_review_missing_slot_keys(
payload: UserAgentRequest,
@@ -2087,6 +2306,7 @@ class UserAgentService:
slot_cards: list[UserAgentReviewSlotCard],
) -> list[str]:
required_keys = {item.key for item in slot_cards if item.required}
slot_map = {item.key: item for item in slot_cards}
missing_keys = {
item.key
for item in slot_cards
@@ -2094,7 +2314,15 @@ class UserAgentService:
}
for key in payload.ontology.missing_slots:
normalized_key = str(key or "").strip()
if normalized_key and normalized_key in required_keys:
if (
normalized_key
and normalized_key in required_keys
and (
normalized_key not in slot_map
or slot_map[normalized_key].status == "missing"
or not str(slot_map[normalized_key].value).strip()
)
):
missing_keys.add(normalized_key)
ordered_keys: list[str] = []
@@ -2257,35 +2485,104 @@ class UserAgentService:
def _extract_document_fields(self, item: dict[str, object]) -> dict[str, str]:
raw_fields = item.get("document_fields")
normalized_fields: dict[str, str] = {}
if isinstance(raw_fields, list):
normalized_fields: dict[str, str] = {}
for field in raw_fields:
if not isinstance(field, dict):
continue
key = str(field.get("key") or "").strip()
label = str(field.get("label") or "").strip()
value = str(field.get("value") or "").strip()
if label and value:
normalized_fields[label] = value
if normalized_fields:
return normalized_fields
if not value:
continue
normalized_label = self._normalize_document_field_label(key=key, label=label)
display_label = normalized_label or label
normalized_value = self._normalize_document_field_value(
label=display_label,
value=value,
)
if display_label and normalized_value:
normalized_fields.setdefault(display_label, normalized_value)
text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()
fields: dict[str, str] = {}
amount_match = AMOUNT_TEXT_PATTERN.search(text)
if amount_match:
fields["金额"] = f"{amount_match.group(1)}"
amount_value = self._extract_amount_text_from_value(text)
if amount_value and "金额" not in normalized_fields:
normalized_fields["金额"] = amount_value
date_match = DATE_TEXT_PATTERN.search(text)
if date_match:
fields["时间"] = date_match.group(1)
if date_match and "时间" not in normalized_fields:
normalized_fields["时间"] = date_match.group(1)
merchant = self._extract_document_merchant_name(item)
if merchant:
fields["商户/酒店"] = merchant
return fields
merchant = self._extract_document_merchant_name_from_text(text)
if merchant and "商户/酒店" not in normalized_fields:
normalized_fields["商户/酒店"] = merchant
return normalized_fields
@staticmethod
def _extract_document_merchant_name(item: dict[str, object]) -> str:
def _normalize_document_field_label(*, key: str, label: str) -> str:
compact_key = str(key or "").strip().lower().replace("_", "")
compact_label = str(label or "").replace(" ", "")
if compact_key in {
"amount",
"totalamount",
"paymentamount",
"paidamount",
"actualamount",
} or any(
token in compact_label
for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
):
return "金额"
if compact_key in {"date", "time", "issuedat", "invoicedate"} or any(
token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")
):
return "时间"
if compact_key in {"merchant", "merchantname", "sellername", "vendorname"} or any(
token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方")
):
return "商户/酒店"
return label
def _normalize_document_field_value(self, *, label: str, value: str) -> str:
normalized_label = str(label or "").strip()
raw_value = str(value or "").strip()
if not normalized_label or not raw_value:
return ""
if normalized_label == "金额":
return self._extract_amount_text_from_value(raw_value) or raw_value
if normalized_label == "时间":
match = DATE_TEXT_PATTERN.search(raw_value)
return match.group(1) if match else raw_value
return raw_value
def _extract_amount_text_from_value(self, value: str) -> str:
raw_value = str(value or "").strip()
if not raw_value:
return ""
best_amount: Decimal | None = None
for pattern in (DOCUMENT_AMOUNT_PATTERN, DOCUMENT_CURRENCY_AMOUNT_PATTERN, AMOUNT_TEXT_PATTERN):
for match in pattern.finditer(raw_value):
try:
candidate = Decimal(str(match.group(1)).replace(",", "."))
except (InvalidOperation, TypeError):
continue
if candidate <= Decimal("0.00"):
continue
if best_amount is None or candidate > best_amount:
best_amount = candidate
if best_amount is None:
return ""
return f"{best_amount.quantize(Decimal('0.01')):.2f}"
def _extract_document_merchant_name(self, item: dict[str, object]) -> str:
fields = self._extract_document_fields(item)
merchant = str(fields.get("商户/酒店") or "").strip()
if merchant:
return merchant
text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()
return self._extract_document_merchant_name_from_text(text)
@staticmethod
def _extract_document_merchant_name_from_text(text: str) -> str:
for keyword in ("酒店", "宾馆", "饭店", "酒楼", "餐厅", "航空", "铁路", "滴滴"):
if keyword in text:
return keyword
@@ -2297,7 +2594,8 @@ class UserAgentService:
if item.label != "金额":
continue
try:
return float(str(item.value).replace("", "").strip())
normalized_value = str(item.value).replace("", "").replace("", "").replace("¥", "").strip()
return float(normalized_value)
except ValueError:
return 0.0
return 0.0
@@ -2315,7 +2613,7 @@ class UserAgentService:
total = 0.0
for item in ocr_documents:
fields = self._extract_document_fields(item)
amount_text = str(fields.get("金额") or "").replace("", "").strip()
amount_text = str(fields.get("金额") or "").replace("", "").replace("", "").replace("¥", "").strip()
if not amount_text:
continue
try:

View File

@@ -46,7 +46,7 @@ def test_user_agent_query_returns_readable_answer_and_actions() -> None:
assert len(response.suggested_actions) >= 1
def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> None:
def test_user_agent_returns_readable_query_answer_when_runtime_model_is_skipped(monkeypatch) -> None:
session_factory = build_session_factory()
with session_factory() as db:
ontology = SemanticOntologyService(db).parse(
@@ -56,11 +56,7 @@ def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) ->
)
)
service = UserAgentService(db)
monkeypatch.setattr(
service,
"_generate_answer_with_model",
lambda *args, **kwargs: "这是模型回答",
)
monkeypatch.setattr(service, "_generate_answer_with_model", lambda *args, **kwargs: "这是模型回答")
response = service.respond(
UserAgentRequest(
@@ -72,7 +68,8 @@ def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) ->
)
)
assert response.answer == "这是模型回答"
assert "共 2 笔" in response.answer
assert "8800.00" in response.answer
def test_user_agent_sanitizes_model_thinking_blocks() -> None:
@@ -144,7 +141,7 @@ def test_user_agent_guides_implicit_expense_draft_request() -> None:
assert response.review_payload is not None
assert response.answer == response.review_payload.body_message
assert response.review_payload.intent_summary.startswith("我理解你这次想报销业务招待费。")
assert response.review_payload.intent_summary.startswith("识别到您希望报销一笔“业务招待费”费用")
assert response.review_payload.missing_slots == ["客户名称", "参与人员", "票据附件"]
assert [item.action_type for item in response.review_payload.confirmation_actions] == [
"cancel_review",
@@ -187,7 +184,102 @@ def test_user_agent_guides_narrative_with_day_before_yesterday() -> None:
slot_map = {item.key: item for item in response.review_payload.slot_cards}
assert slot_map["time_range"].raw_value == "前天"
assert slot_map["time_range"].value == "2026-05-11"
assert "时间2026-05-11" in response.review_payload.intent_summary
assert "时间2026-05-11" in response.review_payload.intent_summary
def test_user_agent_attachment_only_upload_uses_generic_scene_reason_without_fabrication() -> None:
session_factory = build_session_factory()
with session_factory() as db:
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。",
user_id="pytest",
context_json={
"attachment_names": ["didi-trip.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行 订单金额 32 元",
"text": "滴滴出行 订单金额 32 元",
"document_type": "taxi_receipt",
"scene_code": "transport",
}
],
"user_input_text": "",
},
)
)
response = UserAgentService(db).respond(
UserAgentRequest(
run_id=ontology.run_id,
user_id="pytest",
message="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。\n附件名称didi-trip.png",
ontology=ontology,
context_json={
"attachment_names": ["didi-trip.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行 订单金额 32 元",
"text": "滴滴出行 订单金额 32 元",
"document_type": "taxi_receipt",
"scene_code": "transport",
}
],
"user_input_text": "",
},
tool_payload={"draft_only": True},
)
)
assert response.review_payload is not None
slot_map = {item.key: item for item in response.review_payload.slot_cards}
assert slot_map["reason"].value == "交通出行"
assert slot_map["reason"].status == "inferred"
def test_user_agent_transport_flow_infers_reason_and_does_not_require_location_or_merchant() -> None:
session_factory = build_session_factory()
with session_factory() as db:
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了交通票据,帮我生成报销草稿",
user_id="pytest",
)
)
response = UserAgentService(db).respond(
UserAgentRequest(
run_id=ontology.run_id,
user_id="pytest",
message="我上传了交通票据,帮我生成报销草稿",
ontology=ontology,
context_json={
"attachment_names": ["didi-trip.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "didi-trip.png",
"summary": "滴滴出行 支付金额 32 元",
"text": "滴滴出行 支付金额 32 元",
"document_type": "taxi_receipt",
"scene_code": "transport",
"scene_label": "交通票据",
}
],
},
tool_payload={"draft_only": True},
)
)
assert response.review_payload is not None
slot_map = {item.key: item for item in response.review_payload.slot_cards}
assert slot_map["reason"].value == "交通出行"
assert slot_map["reason"].status == "inferred"
assert "酒店/商户" not in response.review_payload.missing_slots
assert "地点" not in response.review_payload.missing_slots
assert "事由说明" not in response.review_payload.missing_slots
def test_user_agent_risk_response_includes_rule_citations() -> None:
@@ -347,7 +439,238 @@ def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> N
"save_draft",
]
assert any(item.scene_label == "业务招待费" for item in response.review_payload.document_cards)
assert f"时间{yesterday}" in response.review_payload.intent_summary
assert f"时间{yesterday}" in response.review_payload.intent_summary
slot_map = {item.key: item for item in response.review_payload.slot_cards}
assert slot_map["time_range"].value == yesterday
assert slot_map["time_range"].raw_value == "昨天"
def test_user_agent_sums_multi_document_amounts_from_synonym_fields() -> None:
session_factory = build_session_factory()
with session_factory() as db:
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了两张交通票据,帮我生成报销草稿",
user_id="pytest",
context_json={
"attachment_names": ["滴滴行程单.png", "停车票.jpg"],
"attachment_count": 2,
"ocr_documents": [
{
"filename": "滴滴行程单.png",
"summary": "滴滴出行电子行程单",
"text": "滴滴出行 订单金额 ¥32.50",
"avg_score": 0.94,
"document_fields": [
{"key": "amount", "label": "支付金额", "value": "32.50"},
],
"warnings": [],
},
{
"filename": "停车票.jpg",
"summary": "停车票",
"text": "停车费 合计 18 元",
"avg_score": 0.92,
"document_fields": [
{"key": "total_amount", "label": "合计金额", "value": "18"},
],
"warnings": [],
},
],
},
)
)
response = UserAgentService(db).respond(
UserAgentRequest(
run_id=ontology.run_id,
user_id="pytest",
message="我上传了两张交通票据,帮我生成报销草稿",
ontology=ontology,
context_json={
"attachment_names": ["滴滴行程单.png", "停车票.jpg"],
"attachment_count": 2,
"ocr_documents": [
{
"filename": "滴滴行程单.png",
"summary": "滴滴出行电子行程单",
"text": "滴滴出行 订单金额 ¥32.50",
"avg_score": 0.94,
"document_fields": [
{"key": "amount", "label": "支付金额", "value": "32.50"},
],
"warnings": [],
},
{
"filename": "停车票.jpg",
"summary": "停车票",
"text": "停车费 合计 18 元",
"avg_score": 0.92,
"document_fields": [
{"key": "total_amount", "label": "合计金额", "value": "18"},
],
"warnings": [],
},
],
},
tool_payload={"draft_only": True},
)
)
assert response.review_payload is not None
slot_map = {item.key: item for item in response.review_payload.slot_cards}
assert slot_map["amount"].value == "50.50元"
document_field_labels = [
field.label
for card in response.review_payload.document_cards
for field in card.fields
]
assert "金额" in document_field_labels
def test_user_agent_prefers_larger_decimal_amount_from_ocr_text_candidates() -> None:
session_factory = build_session_factory()
with session_factory() as db:
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了打车票据,帮我生成报销草稿",
user_id="pytest",
context_json={
"attachment_names": ["滴滴行程单.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "滴滴行程单.png",
"summary": "滴滴出行电子行程单",
"text": "滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678",
"avg_score": 0.94,
"warnings": [],
},
],
},
)
)
response = UserAgentService(db).respond(
UserAgentRequest(
run_id=ontology.run_id,
user_id="pytest",
message="我上传了打车票据,帮我生成报销草稿",
ontology=ontology,
context_json={
"attachment_names": ["滴滴行程单.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "滴滴行程单.png",
"summary": "滴滴出行电子行程单",
"text": "滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678",
"avg_score": 0.94,
"warnings": [],
},
],
},
tool_payload={"draft_only": True},
)
)
assert response.review_payload is not None
slot_map = {item.key: item for item in response.review_payload.slot_cards}
assert slot_map["amount"].value == "13.40元"
def test_user_agent_review_payload_keeps_document_preview_data() -> None:
session_factory = build_session_factory()
with session_factory() as db:
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了打车票据,帮我生成报销草稿",
user_id="pytest",
context_json={
"attachment_names": ["滴滴行程单.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "滴滴行程单.png",
"summary": "滴滴出行电子行程单",
"text": "滴滴出行 实付 13.4 元",
"avg_score": 0.94,
"preview_kind": "image",
"preview_data_url": "data:image/png;base64,ZmFrZQ==",
"warnings": [],
},
],
},
)
)
response = UserAgentService(db).respond(
UserAgentRequest(
run_id=ontology.run_id,
user_id="pytest",
message="我上传了打车票据,帮我生成报销草稿",
ontology=ontology,
context_json={
"attachment_names": ["滴滴行程单.png"],
"attachment_count": 1,
"ocr_documents": [
{
"filename": "滴滴行程单.png",
"summary": "滴滴出行电子行程单",
"text": "滴滴出行 实付 13.4 元",
"avg_score": 0.94,
"preview_kind": "image",
"preview_data_url": "data:image/png;base64,ZmFrZQ==",
"warnings": [],
},
],
},
tool_payload={"draft_only": True},
)
)
assert response.review_payload is not None
assert response.review_payload.document_cards[0].preview_kind == "image"
assert response.review_payload.document_cards[0].preview_data_url.startswith("data:image/png;base64,")
def test_user_agent_prompts_existing_draft_association_choice_for_multi_documents() -> None:
session_factory = build_session_factory()
with session_factory() as db:
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="我上传了两张票据,帮我生成报销草稿",
user_id="pytest",
)
)
response = UserAgentService(db).respond(
UserAgentRequest(
run_id=ontology.run_id,
user_id="pytest",
message="我上传了两张票据,帮我生成报销草稿",
ontology=ontology,
context_json={
"attachment_names": ["滴滴行程单.png", "餐饮发票.jpg"],
"attachment_count": 2,
"ocr_documents": [
{"filename": "滴滴行程单.png", "summary": "滴滴出行 金额 32 元", "text": "滴滴出行 金额 32 元"},
{"filename": "餐饮发票.jpg", "summary": "餐饮发票 金额 68 元", "text": "餐饮发票 金额 68 元"},
],
},
tool_payload={
"pending_association_decision": True,
"association_candidate_claim_no": "EXP-202605-008",
},
)
)
assert response.review_payload is not None
assert response.review_payload.can_proceed is False
assert [item.action_type for item in response.review_payload.confirmation_actions] == [
"cancel_review",
"edit_review",
"link_to_existing_draft",
"create_new_claim_from_documents",
]
assert "EXP-202605-008" in response.answer