feat(server): 更新用户代理服务架构,增强用户行为追踪和会话管理功能,包含schema、service和单元测试

This commit is contained in:
caoxiaozhu
2026-05-14 15:42:33 +00:00
parent fad583ee7c
commit ad16358e71
3 changed files with 664 additions and 41 deletions

View File

@@ -117,6 +117,8 @@ class UserAgentReviewDocumentCard(BaseModel):
scene_label: str = Field(default="", description="面向用户展示的场景标签。") scene_label: str = Field(default="", description="面向用户展示的场景标签。")
summary: str = Field(default="", description="逐票据摘要。") summary: str = Field(default="", description="逐票据摘要。")
avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="OCR 平均得分。") avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="OCR 平均得分。")
preview_kind: str = Field(default="", description="票据预览类型,例如 image。")
preview_data_url: str = Field(default="", description="票据预览图片 data URL。")
warnings: list[str] = Field(default_factory=list, description="该票据的识别提示。") warnings: list[str] = Field(default_factory=list, description="该票据的识别提示。")
fields: list[UserAgentReviewDocumentField] = Field( fields: list[UserAgentReviewDocumentField] = Field(
default_factory=list, default_factory=list,

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import json import json
import re import re
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from decimal import Decimal, InvalidOperation
from sqlalchemy import or_, select from sqlalchemy import or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -118,6 +119,11 @@ SLOT_LABELS = {
DATE_TEXT_PATTERN = re.compile(r"(\d{4}[年/-]\d{1,2}[月/-]\d{1,2}日?)") DATE_TEXT_PATTERN = re.compile(r"(\d{4}[年/-]\d{1,2}[月/-]\d{1,2}日?)")
AMOUNT_TEXT_PATTERN = re.compile(r"(\d+(?:\.\d+)?)\s*(?:元|万元|万)") AMOUNT_TEXT_PATTERN = re.compile(r"(\d+(?:\.\d+)?)\s*(?:元|万元|万)")
# Amount introduced by a billing keyword: the keyword alternatives are followed
# by any run of colons, whitespace, currency marks or "人民币", and group 1
# captures the number (digits plus an optional 1-2 digit fraction after "." or ",").
# NOTE(review): because "," is accepted as a decimal mark, a grouped value such as
# "1,234.56" would be captured as "1,23" — confirm inputs never use thousands separators.
DOCUMENT_AMOUNT_PATTERN = re.compile(
# Keyword alternatives that introduce an amount.
r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
# Optional separators / currency markers, then the captured number.
r"[:\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
)
# Amount introduced only by a currency sign; group 1 has the same number shape as above.
DOCUMENT_CURRENCY_AMOUNT_PATTERN = re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)")
SOURCE_LABELS = { SOURCE_LABELS = {
"user_text": "用户描述", "user_text": "用户描述",
@@ -130,6 +136,36 @@ SOURCE_LABELS = {
"system": "系统判断", "system": "系统判断",
} }
# Extra slot keys that become required when a given scene code is present;
# scene codes without an entry contribute no additional required slots.
SCENE_REQUIRED_SLOT_KEYS = {
"hotel": {"merchant_name"},
"meeting": {"location"},
"entertainment": {"location", "customer_name", "participants"},
}
# Generic reason text keyed by expense type / group code; used to fill the
# reason slot when it can only be inferred from a claim group.
INFERRED_REASON_LABELS = {
"travel": "出差行程",
"hotel": "住宿报销",
"transport": "交通出行",
"meal": "餐饮用餐",
"meeting": "会务活动",
"entertainment": "客户接待",
"office": "办公采购",
"training": "培训学习",
"communication": "通讯使用",
"welfare": "员工福利",
"other": "其他费用",
}
# Message prefixes treated as system-generated boilerplate rather than a
# user-written reason. They are matched with str.startswith against the message
# after all whitespace has been removed, so entries must not contain whitespace.
SYSTEM_GENERATED_REASON_PREFIXES = (
"我上传了",
"请按当前已识别信息",
"请把当前上传的票据",
"请基于当前上传的多张票据",
"我已核对右侧识别结果",
"请同步修正逐票据识别结果",
"我已修改识别信息",
"查看报销草稿",
"请解释一下当前这笔报销的合规风险和待补充项",
)
class UserAgentService: class UserAgentService:
def __init__(self, db: Session) -> None: def __init__(self, db: Session) -> None:
@@ -736,10 +772,15 @@ class UserAgentService:
document_cards=document_cards, document_cards=document_cards,
claim_groups=claim_groups, claim_groups=claim_groups,
) )
can_proceed = self._can_proceed_review( association_choice_pending = self._is_review_association_choice_pending(payload)
payload, can_proceed = (
missing_slot_keys=missing_slot_keys, False
claim_groups=claim_groups, if association_choice_pending
else self._can_proceed_review(
payload,
missing_slot_keys=missing_slot_keys,
claim_groups=claim_groups,
)
) )
confirmation_actions = self._build_review_confirmation_actions( confirmation_actions = self._build_review_confirmation_actions(
payload, payload,
@@ -762,6 +803,7 @@ class UserAgentService:
slot_cards=slot_cards, slot_cards=slot_cards,
risk_briefs=risk_briefs, risk_briefs=risk_briefs,
can_proceed=can_proceed, can_proceed=can_proceed,
document_cards=document_cards,
) )
return UserAgentReviewPayload( return UserAgentReviewPayload(
@@ -798,7 +840,10 @@ class UserAgentService:
ocr_documents=ocr_documents, ocr_documents=ocr_documents,
) )
merchant_slot = self._build_merchant_slot(payload, ocr_documents=ocr_documents) merchant_slot = self._build_merchant_slot(payload, ocr_documents=ocr_documents)
reason_slot = self._build_reason_slot(payload) reason_slot = self._build_reason_slot(
payload,
claim_groups=claim_groups,
)
attachment_slot = self._build_attachment_slot(payload) attachment_slot = self._build_attachment_slot(payload)
required_keys = self._resolve_required_review_keys( required_keys = self._resolve_required_review_keys(
payload, payload,
@@ -922,6 +967,8 @@ class UserAgentService:
), ),
summary=str(item.get("summary") or item.get("text") or "").strip(), summary=str(item.get("summary") or item.get("text") or "").strip(),
avg_score=float(item.get("avg_score") or 0.0), avg_score=float(item.get("avg_score") or 0.0),
preview_kind=str(item.get("preview_kind") or "").strip(),
preview_data_url=str(item.get("preview_data_url") or "").strip(),
warnings=[str(warning) for warning in item.get("warnings", []) if str(warning).strip()], warnings=[str(warning) for warning in item.get("warnings", []) if str(warning).strip()],
fields=[ fields=[
UserAgentReviewDocumentField( UserAgentReviewDocumentField(
@@ -950,14 +997,22 @@ class UserAgentService:
{ {
"document_indexes": [], "document_indexes": [],
"amount_total": 0.0, "amount_total": 0.0,
"expense_type": group_code, "expense_type": str(card.suggested_expense_type or group_code).strip() or group_code,
"scene_label": GROUP_SCENE_LABELS.get(group_code, "其他费用"), "scene_label": GROUP_SCENE_LABELS.get(
str(card.suggested_expense_type or group_code).strip() or group_code,
GROUP_SCENE_LABELS.get(group_code, "其他费用"),
),
"reasons": [], "reasons": [],
}, },
) )
bucket["document_indexes"].append(card.index) bucket["document_indexes"].append(card.index)
bucket["amount_total"] = float(bucket["amount_total"]) + self._extract_amount_from_card(card) bucket["amount_total"] = float(bucket["amount_total"]) + self._extract_amount_from_card(card)
bucket["reasons"].append(f"{card.filename} 识别为 {card.scene_label}") bucket["reasons"].append(f"{card.filename} 识别为 {card.scene_label}")
current_expense_type = str(bucket["expense_type"] or "").strip()
current_card_type = str(card.suggested_expense_type or "").strip()
if current_expense_type and current_card_type and current_expense_type != current_card_type:
bucket["expense_type"] = group_code
bucket["scene_label"] = GROUP_SCENE_LABELS.get(group_code, "其他费用")
if not groups: if not groups:
expense_type_code = self._collect_entity_values(payload).get("expense_type_code", "other") expense_type_code = self._collect_entity_values(payload).get("expense_type_code", "other")
@@ -1080,6 +1135,40 @@ class UserAgentService:
claim_groups: list[UserAgentReviewClaimGroup], claim_groups: list[UserAgentReviewClaimGroup],
draft_payload: UserAgentDraftPayload | None, draft_payload: UserAgentDraftPayload | None,
) -> list[UserAgentReviewAction]: ) -> list[UserAgentReviewAction]:
if self._is_review_association_choice_pending(payload):
claim_no = str(payload.tool_payload.get("association_candidate_claim_no") or "").strip()
link_label = f"关联到草稿 {claim_no}" if claim_no else "关联到现有草稿"
return [
UserAgentReviewAction(
label="取消",
action_type="cancel_review",
description="放弃当前识别结果,并退出本次核对流程。",
emphasis="secondary",
),
UserAgentReviewAction(
label="修改识别信息",
action_type="edit_review",
description="打开结构化模板,按已识别字段逐项修改。",
emphasis="secondary",
),
UserAgentReviewAction(
label=link_label,
action_type="link_to_existing_draft",
description=(
f"把本次上传票据并入现有草稿 {claim_no}"
if claim_no
else "把本次上传票据并入现有草稿。"
),
emphasis="primary",
),
UserAgentReviewAction(
label="单独建立报销单",
action_type="create_new_claim_from_documents",
description="基于当前上传的多张票据,新建一张独立的报销草稿。",
emphasis="secondary",
),
]
primary_action = UserAgentReviewAction( primary_action = UserAgentReviewAction(
label="继续下一步" if can_proceed else "保存为草稿", label="继续下一步" if can_proceed else "保存为草稿",
action_type="next_step" if can_proceed else "save_draft", action_type="next_step" if can_proceed else "save_draft",
@@ -1171,6 +1260,22 @@ class UserAgentService:
"后续您可以继续补充缺失项,或修改识别结果后再继续提交。" "后续您可以继续补充缺失项,或修改识别结果后再继续提交。"
) )
return "已按您当前确认的信息保存为草稿。后续您可以继续补充缺失项,或修改识别结果后再继续提交。" return "已按您当前确认的信息保存为草稿。后续您可以继续补充缺失项,或修改识别结果后再继续提交。"
if review_action == "link_to_existing_draft":
document_count = self._resolve_review_document_count(payload)
if draft_payload is not None and draft_payload.claim_no:
return (
f"已将本次上传的 {document_count} 张票据关联到草稿 {draft_payload.claim_no}"
"您可以继续补充识别字段,确认无误后再提交审批。"
)
return "已将本次上传的票据关联到现有草稿。您可以继续补充识别字段,确认无误后再提交审批。"
if review_action == "create_new_claim_from_documents":
document_count = self._resolve_review_document_count(payload)
if draft_payload is not None and draft_payload.claim_no:
return (
f"已按当前上传的 {document_count} 张票据新建报销草稿 {draft_payload.claim_no}"
"您可以继续补充识别字段,确认无误后再提交审批。"
)
return "已按当前上传票据新建报销草稿。您可以继续补充识别字段,确认无误后再提交审批。"
if review_action == "next_step": if review_action == "next_step":
if draft_payload is not None and draft_payload.status == "submitted": if draft_payload is not None and draft_payload.status == "submitted":
stage_text = draft_payload.approval_stage or "审批中" stage_text = draft_payload.approval_stage or "审批中"
@@ -1195,7 +1300,21 @@ class UserAgentService:
slot_cards: list[UserAgentReviewSlotCard], slot_cards: list[UserAgentReviewSlotCard],
risk_briefs: list[UserAgentReviewRiskBrief], risk_briefs: list[UserAgentReviewRiskBrief],
can_proceed: bool, can_proceed: bool,
document_cards: list[UserAgentReviewDocumentCard],
) -> str: ) -> str:
if self._is_review_association_choice_pending(payload):
claim_no = str(payload.tool_payload.get("association_candidate_claim_no") or "").strip()
document_count = len(document_cards) or self._resolve_review_document_count(payload)
if claim_no:
return (
f"已识别出本次上传的 {document_count} 张票据。"
f"系统检测到你已有草稿 {claim_no},请选择关联到该草稿,或单独建立一张新的报销单。"
)
return (
f"已识别出本次上传的 {document_count} 张票据。"
"系统检测到你已有可用草稿,请先选择关联到现有草稿,或单独建立一张新的报销单。"
)
review_payload = UserAgentReviewPayload( review_payload = UserAgentReviewPayload(
intent_summary="", intent_summary="",
body_message="", body_message="",
@@ -1423,6 +1542,22 @@ class UserAgentService:
return cleaned[:300] return cleaned[:300]
return "" return ""
@staticmethod
def _looks_like_system_generated_reason_message(message: str) -> bool:
cleaned = str(message or "").strip()
if not cleaned:
return False
compact = re.sub(r"\s+", "", cleaned)
return compact.startswith(SYSTEM_GENERATED_REASON_PREFIXES)
def _resolve_reason_source_text(self, payload: UserAgentRequest) -> str:
explicit_text = payload.context_json.get("user_input_text")
if isinstance(explicit_text, str):
return explicit_text.strip()
if self._looks_like_system_generated_reason_message(payload.message):
return ""
return str(payload.message or "").strip()
@classmethod @classmethod
def _resolve_reason_text(cls, message: str) -> str: def _resolve_reason_text(cls, message: str) -> str:
reason = cls._extract_message_reason(message) reason = cls._extract_message_reason(message)
@@ -1553,13 +1688,58 @@ class UserAgentService:
documents = payload.context_json.get("ocr_documents") documents = payload.context_json.get("ocr_documents")
if not isinstance(documents, list): if not isinstance(documents, list):
return [] return []
overrides = payload.context_json.get("review_document_form_values")
override_map: dict[tuple[int, str], dict[str, object]] = {}
if isinstance(overrides, list):
for item in overrides:
if not isinstance(item, dict):
continue
filename = str(item.get("filename") or "").strip()
index = int(item.get("index") or 0)
if not filename and index <= 0:
continue
override_map[(index, filename)] = item
normalized: list[dict[str, object]] = [] normalized: list[dict[str, object]] = []
for item in documents[:8]: for index, item in enumerate(documents[:8], start=1):
if not isinstance(item, dict): if not isinstance(item, dict):
continue continue
normalized.append(item) normalized_item = dict(item)
override = override_map.get((index, str(normalized_item.get("filename") or "").strip()))
if override is None:
override = override_map.get((index, ""))
if override is not None:
summary = str(override.get("summary") or "").strip()
scene_label = str(override.get("scene_label") or "").strip()
fields = override.get("fields")
if summary:
normalized_item["summary"] = summary
if scene_label:
normalized_item["scene_label"] = scene_label
if isinstance(fields, list):
normalized_item["document_fields"] = [
{
"key": str(field.get("key") or field.get("label") or "").strip(),
"label": str(field.get("label") or "").strip(),
"value": str(field.get("value") or "").strip(),
}
for field in fields
if isinstance(field, dict)
and str(field.get("label") or "").strip()
and str(field.get("value") or "").strip()
]
normalized.append(normalized_item)
return normalized return normalized
@staticmethod
def _is_review_association_choice_pending(payload: UserAgentRequest) -> bool:
return bool(payload.tool_payload.get("pending_association_decision"))
def _resolve_review_document_count(self, payload: UserAgentRequest) -> int:
return max(
len(self._resolve_ocr_documents(payload)),
self._resolve_attachment_count(payload),
)
@staticmethod @staticmethod
def _resolve_conversation_history(payload: UserAgentRequest) -> list[dict[str, object]]: def _resolve_conversation_history(payload: UserAgentRequest) -> list[dict[str, object]]:
history = payload.context_json.get("conversation_history") history = payload.context_json.get("conversation_history")
@@ -1852,7 +2032,12 @@ class UserAgentService:
) )
return self._build_slot_value() return self._build_slot_value()
def _build_reason_slot(self, payload: UserAgentRequest) -> dict[str, str | float]: def _build_reason_slot(
self,
payload: UserAgentRequest,
*,
claim_groups: list[UserAgentReviewClaimGroup],
) -> dict[str, str | float]:
review_form_values = self._resolve_review_form_values(payload) review_form_values = self._resolve_review_form_values(payload)
edited_value = str(review_form_values.get("reason") or "").strip() edited_value = str(review_form_values.get("reason") or "").strip()
if edited_value: if edited_value:
@@ -1865,7 +2050,7 @@ class UserAgentService:
evidence="来源于用户修改后的结构化表单。", evidence="来源于用户修改后的结构化表单。",
) )
reason_value = self._resolve_reason_text(payload.message) reason_value = self._resolve_reason_text(self._resolve_reason_source_text(payload))
if reason_value: if reason_value:
return self._build_slot_value( return self._build_slot_value(
value=reason_value, value=reason_value,
@@ -1875,6 +2060,19 @@ class UserAgentService:
confidence=0.76, confidence=0.76,
evidence="系统从用户原始描述中提取了本次费用事由,建议继续核对。", evidence="系统从用户原始描述中提取了本次费用事由,建议继续核对。",
) )
inferred_reason = self._infer_reason_from_claim_groups(
claim_groups=claim_groups,
)
if inferred_reason:
return self._build_slot_value(
value=inferred_reason,
raw_value=inferred_reason,
normalized_value=inferred_reason,
source="ocr",
confidence=0.68,
evidence="系统已根据票据识别场景补全通用事由,若需更具体说明可继续修改。",
)
return self._build_slot_value() return self._build_slot_value()
def _build_amount_slot( def _build_amount_slot(
@@ -2072,7 +2270,10 @@ class UserAgentService:
if primary_expense_type: if primary_expense_type:
scene_codes.add(primary_expense_type) scene_codes.add(primary_expense_type)
compact_message = re.sub(r"\s+", "", payload.message) for scene_code in scene_codes:
required.update(SCENE_REQUIRED_SLOT_KEYS.get(scene_code, set()))
compact_message = re.sub(r"\s+", "", self._resolve_reason_source_text(payload) or payload.message)
if "entertainment" in scene_codes or ( if "entertainment" in scene_codes or (
"客户" in compact_message and any(keyword in compact_message for keyword in ("招待", "吃饭", "用餐", "宴请", "请客")) "客户" in compact_message and any(keyword in compact_message for keyword in ("招待", "吃饭", "用餐", "宴请", "请客"))
): ):
@@ -2080,6 +2281,24 @@ class UserAgentService:
return required return required
@staticmethod
def _infer_reason_from_claim_groups(
*,
claim_groups: list[UserAgentReviewClaimGroup],
) -> str:
if len(claim_groups) == 1:
document_indexes = list(claim_groups[0].document_indexes or [])
if not document_indexes:
return ""
expense_type = str(claim_groups[0].expense_type or "").strip()
group_code = str(claim_groups[0].group_code or "").strip()
if expense_type:
return INFERRED_REASON_LABELS.get(expense_type, "") or str(claim_groups[0].scene_label or "").strip()
if group_code:
return INFERRED_REASON_LABELS.get(group_code, "") or str(claim_groups[0].scene_label or "").strip()
return ""
@staticmethod @staticmethod
def _resolve_review_missing_slot_keys( def _resolve_review_missing_slot_keys(
payload: UserAgentRequest, payload: UserAgentRequest,
@@ -2087,6 +2306,7 @@ class UserAgentService:
slot_cards: list[UserAgentReviewSlotCard], slot_cards: list[UserAgentReviewSlotCard],
) -> list[str]: ) -> list[str]:
required_keys = {item.key for item in slot_cards if item.required} required_keys = {item.key for item in slot_cards if item.required}
slot_map = {item.key: item for item in slot_cards}
missing_keys = { missing_keys = {
item.key item.key
for item in slot_cards for item in slot_cards
@@ -2094,7 +2314,15 @@ class UserAgentService:
} }
for key in payload.ontology.missing_slots: for key in payload.ontology.missing_slots:
normalized_key = str(key or "").strip() normalized_key = str(key or "").strip()
if normalized_key and normalized_key in required_keys: if (
normalized_key
and normalized_key in required_keys
and (
normalized_key not in slot_map
or slot_map[normalized_key].status == "missing"
or not str(slot_map[normalized_key].value).strip()
)
):
missing_keys.add(normalized_key) missing_keys.add(normalized_key)
ordered_keys: list[str] = [] ordered_keys: list[str] = []
@@ -2257,35 +2485,104 @@ class UserAgentService:
def _extract_document_fields(self, item: dict[str, object]) -> dict[str, str]: def _extract_document_fields(self, item: dict[str, object]) -> dict[str, str]:
raw_fields = item.get("document_fields") raw_fields = item.get("document_fields")
normalized_fields: dict[str, str] = {}
if isinstance(raw_fields, list): if isinstance(raw_fields, list):
normalized_fields: dict[str, str] = {}
for field in raw_fields: for field in raw_fields:
if not isinstance(field, dict): if not isinstance(field, dict):
continue continue
key = str(field.get("key") or "").strip()
label = str(field.get("label") or "").strip() label = str(field.get("label") or "").strip()
value = str(field.get("value") or "").strip() value = str(field.get("value") or "").strip()
if label and value: if not value:
normalized_fields[label] = value continue
if normalized_fields: normalized_label = self._normalize_document_field_label(key=key, label=label)
return normalized_fields display_label = normalized_label or label
normalized_value = self._normalize_document_field_value(
label=display_label,
value=value,
)
if display_label and normalized_value:
normalized_fields.setdefault(display_label, normalized_value)
text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip() text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()
fields: dict[str, str] = {} amount_value = self._extract_amount_text_from_value(text)
amount_match = AMOUNT_TEXT_PATTERN.search(text) if amount_value and "金额" not in normalized_fields:
if amount_match: normalized_fields["金额"] = amount_value
fields["金额"] = f"{amount_match.group(1)}"
date_match = DATE_TEXT_PATTERN.search(text) date_match = DATE_TEXT_PATTERN.search(text)
if date_match: if date_match and "时间" not in normalized_fields:
fields["时间"] = date_match.group(1) normalized_fields["时间"] = date_match.group(1)
merchant = self._extract_document_merchant_name(item) merchant = self._extract_document_merchant_name_from_text(text)
if merchant: if merchant and "商户/酒店" not in normalized_fields:
fields["商户/酒店"] = merchant normalized_fields["商户/酒店"] = merchant
return fields return normalized_fields
@staticmethod @staticmethod
def _extract_document_merchant_name(item: dict[str, object]) -> str: def _normalize_document_field_label(*, key: str, label: str) -> str:
compact_key = str(key or "").strip().lower().replace("_", "")
compact_label = str(label or "").replace(" ", "")
if compact_key in {
"amount",
"totalamount",
"paymentamount",
"paidamount",
"actualamount",
} or any(
token in compact_label
for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
):
return "金额"
if compact_key in {"date", "time", "issuedat", "invoicedate"} or any(
token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")
):
return "时间"
if compact_key in {"merchant", "merchantname", "sellername", "vendorname"} or any(
token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方")
):
return "商户/酒店"
return label
def _normalize_document_field_value(self, *, label: str, value: str) -> str:
normalized_label = str(label or "").strip()
raw_value = str(value or "").strip()
if not normalized_label or not raw_value:
return ""
if normalized_label == "金额":
return self._extract_amount_text_from_value(raw_value) or raw_value
if normalized_label == "时间":
match = DATE_TEXT_PATTERN.search(raw_value)
return match.group(1) if match else raw_value
return raw_value
def _extract_amount_text_from_value(self, value: str) -> str:
raw_value = str(value or "").strip()
if not raw_value:
return ""
best_amount: Decimal | None = None
for pattern in (DOCUMENT_AMOUNT_PATTERN, DOCUMENT_CURRENCY_AMOUNT_PATTERN, AMOUNT_TEXT_PATTERN):
for match in pattern.finditer(raw_value):
try:
candidate = Decimal(str(match.group(1)).replace(",", "."))
except (InvalidOperation, TypeError):
continue
if candidate <= Decimal("0.00"):
continue
if best_amount is None or candidate > best_amount:
best_amount = candidate
if best_amount is None:
return ""
return f"{best_amount.quantize(Decimal('0.01')):.2f}"
def _extract_document_merchant_name(self, item: dict[str, object]) -> str:
fields = self._extract_document_fields(item)
merchant = str(fields.get("商户/酒店") or "").strip()
if merchant:
return merchant
text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip() text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()
return self._extract_document_merchant_name_from_text(text)
@staticmethod
def _extract_document_merchant_name_from_text(text: str) -> str:
for keyword in ("酒店", "宾馆", "饭店", "酒楼", "餐厅", "航空", "铁路", "滴滴"): for keyword in ("酒店", "宾馆", "饭店", "酒楼", "餐厅", "航空", "铁路", "滴滴"):
if keyword in text: if keyword in text:
return keyword return keyword
@@ -2297,7 +2594,8 @@ class UserAgentService:
if item.label != "金额": if item.label != "金额":
continue continue
try: try:
return float(str(item.value).replace("", "").strip()) normalized_value = str(item.value).replace("", "").replace("", "").replace("¥", "").strip()
return float(normalized_value)
except ValueError: except ValueError:
return 0.0 return 0.0
return 0.0 return 0.0
@@ -2315,7 +2613,7 @@ class UserAgentService:
total = 0.0 total = 0.0
for item in ocr_documents: for item in ocr_documents:
fields = self._extract_document_fields(item) fields = self._extract_document_fields(item)
amount_text = str(fields.get("金额") or "").replace("", "").strip() amount_text = str(fields.get("金额") or "").replace("", "").replace("", "").replace("¥", "").strip()
if not amount_text: if not amount_text:
continue continue
try: try:

View File

@@ -46,7 +46,7 @@ def test_user_agent_query_returns_readable_answer_and_actions() -> None:
assert len(response.suggested_actions) >= 1 assert len(response.suggested_actions) >= 1
def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> None: def test_user_agent_returns_readable_query_answer_when_runtime_model_is_skipped(monkeypatch) -> None:
session_factory = build_session_factory() session_factory = build_session_factory()
with session_factory() as db: with session_factory() as db:
ontology = SemanticOntologyService(db).parse( ontology = SemanticOntologyService(db).parse(
@@ -56,11 +56,7 @@ def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) ->
) )
) )
service = UserAgentService(db) service = UserAgentService(db)
monkeypatch.setattr( monkeypatch.setattr(service, "_generate_answer_with_model", lambda *args, **kwargs: "这是模型回答")
service,
"_generate_answer_with_model",
lambda *args, **kwargs: "这是模型回答",
)
response = service.respond( response = service.respond(
UserAgentRequest( UserAgentRequest(
@@ -72,7 +68,8 @@ def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) ->
) )
) )
assert response.answer == "这是模型回答" assert "共 2 笔" in response.answer
assert "8800.00" in response.answer
def test_user_agent_sanitizes_model_thinking_blocks() -> None: def test_user_agent_sanitizes_model_thinking_blocks() -> None:
@@ -144,7 +141,7 @@ def test_user_agent_guides_implicit_expense_draft_request() -> None:
assert response.review_payload is not None assert response.review_payload is not None
assert response.answer == response.review_payload.body_message assert response.answer == response.review_payload.body_message
assert response.review_payload.intent_summary.startswith("我理解你这次想报销业务招待费。") assert response.review_payload.intent_summary.startswith("识别到您希望报销一笔“业务招待费”费用")
assert response.review_payload.missing_slots == ["客户名称", "参与人员", "票据附件"] assert response.review_payload.missing_slots == ["客户名称", "参与人员", "票据附件"]
assert [item.action_type for item in response.review_payload.confirmation_actions] == [ assert [item.action_type for item in response.review_payload.confirmation_actions] == [
"cancel_review", "cancel_review",
@@ -187,7 +184,102 @@ def test_user_agent_guides_narrative_with_day_before_yesterday() -> None:
slot_map = {item.key: item for item in response.review_payload.slot_cards} slot_map = {item.key: item for item in response.review_payload.slot_cards}
assert slot_map["time_range"].raw_value == "前天" assert slot_map["time_range"].raw_value == "前天"
assert slot_map["time_range"].value == "2026-05-11" assert slot_map["time_range"].value == "2026-05-11"
assert "时间2026-05-11" in response.review_payload.intent_summary assert "时间2026-05-11" in response.review_payload.intent_summary
def test_user_agent_attachment_only_upload_uses_generic_scene_reason_without_fabrication() -> None:
    """An attachment-only upload with empty user text gets the generic transport reason."""

    def build_context() -> dict[str, object]:
        # Built fresh per request so the two payloads never share mutable state.
        return {
            "attachment_names": ["didi-trip.png"],
            "attachment_count": 1,
            "ocr_documents": [
                {
                    "filename": "didi-trip.png",
                    "summary": "滴滴出行 订单金额 32 元",
                    "text": "滴滴出行 订单金额 32 元",
                    "document_type": "taxi_receipt",
                    "scene_code": "transport",
                }
            ],
            "user_input_text": "",
        }

    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(
            query="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。",
            user_id="pytest",
            context_json=build_context(),
        )
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。\n附件名称didi-trip.png",
            ontology=ontology,
            context_json=build_context(),
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        review = response.review_payload
        assert review is not None
        cards_by_key = dict((card.key, card) for card in review.slot_cards)
        assert cards_by_key["reason"].value == "交通出行"
        assert cards_by_key["reason"].status == "inferred"
def test_user_agent_transport_flow_infers_reason_and_does_not_require_location_or_merchant() -> None:
    """A transport receipt yields an inferred reason and no location/merchant/reason gaps."""
    user_message = "我上传了交通票据,帮我生成报销草稿"
    taxi_document = {
        "filename": "didi-trip.png",
        "summary": "滴滴出行 支付金额 32 元",
        "text": "滴滴出行 支付金额 32 元",
        "document_type": "taxi_receipt",
        "scene_code": "transport",
        "scene_label": "交通票据",
    }

    session_factory = build_session_factory()
    with session_factory() as db:
        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=user_message, user_id="pytest")
        )
        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=user_message,
            ontology=ontology,
            context_json={
                "attachment_names": ["didi-trip.png"],
                "attachment_count": 1,
                "ocr_documents": [taxi_document],
            },
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        review = response.review_payload
        assert review is not None
        cards_by_key = dict((card.key, card) for card in review.slot_cards)
        assert cards_by_key["reason"].value == "交通出行"
        assert cards_by_key["reason"].status == "inferred"
        # None of these labels may be reported as missing for a transport receipt.
        assert "酒店/商户" not in review.missing_slots
        assert "地点" not in review.missing_slots
        assert "事由说明" not in review.missing_slots
def test_user_agent_risk_response_includes_rule_citations() -> None: def test_user_agent_risk_response_includes_rule_citations() -> None:
@@ -347,7 +439,238 @@ def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> N
"save_draft", "save_draft",
] ]
assert any(item.scene_label == "业务招待费" for item in response.review_payload.document_cards) assert any(item.scene_label == "业务招待费" for item in response.review_payload.document_cards)
assert f"时间{yesterday}" in response.review_payload.intent_summary assert f"时间{yesterday}" in response.review_payload.intent_summary
slot_map = {item.key: item for item in response.review_payload.slot_cards} slot_map = {item.key: item for item in response.review_payload.slot_cards}
assert slot_map["time_range"].value == yesterday assert slot_map["time_range"].value == yesterday
assert slot_map["time_range"].raw_value == "昨天" assert slot_map["time_range"].raw_value == "昨天"
def test_user_agent_sums_multi_document_amounts_from_synonym_fields() -> None:
    """Amounts labelled with synonyms are normalized to "金额" and summed across documents."""

    def build_context() -> dict[str, object]:
        # Built fresh per request so the two payloads never share mutable state.
        ride_document = {
            "filename": "滴滴行程单.png",
            "summary": "滴滴出行电子行程单",
            "text": "滴滴出行 订单金额 ¥32.50",
            "avg_score": 0.94,
            "document_fields": [
                {"key": "amount", "label": "支付金额", "value": "32.50"},
            ],
            "warnings": [],
        }
        parking_document = {
            "filename": "停车票.jpg",
            "summary": "停车票",
            "text": "停车费 合计 18 元",
            "avg_score": 0.92,
            "document_fields": [
                {"key": "total_amount", "label": "合计金额", "value": "18"},
            ],
            "warnings": [],
        }
        return {
            "attachment_names": ["滴滴行程单.png", "停车票.jpg"],
            "attachment_count": 2,
            "ocr_documents": [ride_document, parking_document],
        }

    user_message = "我上传了两张交通票据,帮我生成报销草稿"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(
            query=user_message,
            user_id="pytest",
            context_json=build_context(),
        )
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=user_message,
            ontology=ontology,
            context_json=build_context(),
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        review = response.review_payload
        assert review is not None
        cards_by_key = dict((card.key, card) for card in review.slot_cards)
        assert cards_by_key["amount"].value == "50.50元"

        document_field_labels = []
        for document_card in review.document_cards:
            for document_field in document_card.fields:
                document_field_labels.append(document_field.label)
        assert "金额" in document_field_labels
def test_user_agent_prefers_larger_decimal_amount_from_ocr_text_candidates() -> None:
    """When OCR text offers several amounts, the larger keyword-introduced one is used."""

    def build_context() -> dict[str, object]:
        # Built fresh per request so the two payloads never share mutable state.
        return {
            "attachment_names": ["滴滴行程单.png"],
            "attachment_count": 1,
            "ocr_documents": [
                {
                    "filename": "滴滴行程单.png",
                    "summary": "滴滴出行电子行程单",
                    "text": "滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678",
                    "avg_score": 0.94,
                    "warnings": [],
                },
            ],
        }

    user_message = "我上传了打车票据,帮我生成报销草稿"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(
            query=user_message,
            user_id="pytest",
            context_json=build_context(),
        )
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=user_message,
            ontology=ontology,
            context_json=build_context(),
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        review = response.review_payload
        assert review is not None
        cards_by_key = dict((card.key, card) for card in review.slot_cards)
        assert cards_by_key["amount"].value == "13.40元"
def test_user_agent_review_payload_keeps_document_preview_data() -> None:
    """Preview kind and data URL from the OCR document survive into the document card."""

    def build_context() -> dict[str, object]:
        # Built fresh per request so the two payloads never share mutable state.
        return {
            "attachment_names": ["滴滴行程单.png"],
            "attachment_count": 1,
            "ocr_documents": [
                {
                    "filename": "滴滴行程单.png",
                    "summary": "滴滴出行电子行程单",
                    "text": "滴滴出行 实付 13.4 元",
                    "avg_score": 0.94,
                    "preview_kind": "image",
                    "preview_data_url": "data:image/png;base64,ZmFrZQ==",
                    "warnings": [],
                },
            ],
        }

    user_message = "我上传了打车票据,帮我生成报销草稿"
    session_factory = build_session_factory()
    with session_factory() as db:
        parse_request = OntologyParseRequest(
            query=user_message,
            user_id="pytest",
            context_json=build_context(),
        )
        ontology = SemanticOntologyService(db).parse(parse_request)

        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=user_message,
            ontology=ontology,
            context_json=build_context(),
            tool_payload={"draft_only": True},
        )
        response = UserAgentService(db).respond(agent_request)

        review = response.review_payload
        assert review is not None
        first_card = review.document_cards[0]
        assert first_card.preview_kind == "image"
        assert first_card.preview_data_url.startswith("data:image/png;base64,")
def test_user_agent_prompts_existing_draft_association_choice_for_multi_documents() -> None:
    """A pending association decision blocks proceeding and offers link/create actions."""
    user_message = "我上传了两张票据,帮我生成报销草稿"
    uploaded_documents = [
        {"filename": "滴滴行程单.png", "summary": "滴滴出行 金额 32 元", "text": "滴滴出行 金额 32 元"},
        {"filename": "餐饮发票.jpg", "summary": "餐饮发票 金额 68 元", "text": "餐饮发票 金额 68 元"},
    ]
    expected_action_types = [
        "cancel_review",
        "edit_review",
        "link_to_existing_draft",
        "create_new_claim_from_documents",
    ]

    session_factory = build_session_factory()
    with session_factory() as db:
        ontology = SemanticOntologyService(db).parse(
            OntologyParseRequest(query=user_message, user_id="pytest")
        )
        agent_request = UserAgentRequest(
            run_id=ontology.run_id,
            user_id="pytest",
            message=user_message,
            ontology=ontology,
            context_json={
                "attachment_names": ["滴滴行程单.png", "餐饮发票.jpg"],
                "attachment_count": 2,
                "ocr_documents": uploaded_documents,
            },
            tool_payload={
                "pending_association_decision": True,
                "association_candidate_claim_no": "EXP-202605-008",
            },
        )
        response = UserAgentService(db).respond(agent_request)

        review = response.review_payload
        assert review is not None
        assert review.can_proceed is False
        offered_action_types = [action.action_type for action in review.confirmation_actions]
        assert offered_action_types == expected_action_types
        # The candidate draft number must be surfaced in the answer text.
        assert "EXP-202605-008" in response.answer