diff --git a/server/src/app/schemas/user_agent.py b/server/src/app/schemas/user_agent.py index c124a94..e3b538c 100644 --- a/server/src/app/schemas/user_agent.py +++ b/server/src/app/schemas/user_agent.py @@ -117,6 +117,8 @@ class UserAgentReviewDocumentCard(BaseModel): scene_label: str = Field(default="", description="面向用户展示的场景标签。") summary: str = Field(default="", description="逐票据摘要。") avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="OCR 平均得分。") + preview_kind: str = Field(default="", description="票据预览类型,例如 image。") + preview_data_url: str = Field(default="", description="票据预览图片 data URL。") warnings: list[str] = Field(default_factory=list, description="该票据的识别提示。") fields: list[UserAgentReviewDocumentField] = Field( default_factory=list, diff --git a/server/src/app/services/user_agent.py b/server/src/app/services/user_agent.py index 40a29eb..885fe81 100644 --- a/server/src/app/services/user_agent.py +++ b/server/src/app/services/user_agent.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import re from datetime import UTC, datetime, timedelta +from decimal import Decimal, InvalidOperation from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -118,6 +119,11 @@ SLOT_LABELS = { DATE_TEXT_PATTERN = re.compile(r"(\d{4}[年/-]\d{1,2}[月/-]\d{1,2}日?)") AMOUNT_TEXT_PATTERN = re.compile(r"(\d+(?:\.\d+)?)\s*(?:元|万元|万)") +DOCUMENT_AMOUNT_PATTERN = re.compile( + r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)" + r"[::\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)" +) +DOCUMENT_CURRENCY_AMOUNT_PATTERN = re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)") SOURCE_LABELS = { "user_text": "用户描述", @@ -130,6 +136,36 @@ SOURCE_LABELS = { "system": "系统判断", } +SCENE_REQUIRED_SLOT_KEYS = { + "hotel": {"merchant_name"}, + "meeting": {"location"}, + "entertainment": {"location", "customer_name", "participants"}, +} +INFERRED_REASON_LABELS = { + "travel": "出差行程", + "hotel": "住宿报销", + "transport": "交通出行", + "meal": "餐饮用餐", + "meeting": "会务活动", + "entertainment": "客户接待", + "office": "办公采购", + "training": "培训学习", + "communication": "通讯使用", + "welfare": "员工福利", + "other": "其他费用", +} +SYSTEM_GENERATED_REASON_PREFIXES = ( + "我上传了", + "请按当前已识别信息", + "请把当前上传的票据", + "请基于当前上传的多张票据", + "我已核对右侧识别结果", + "请同步修正逐票据识别结果", + "我已修改识别信息", + "查看报销草稿", + "请解释一下当前这笔报销的合规风险和待补充项", +) + class UserAgentService: def __init__(self, db: Session) -> None: @@ -736,10 +772,15 @@ class UserAgentService: document_cards=document_cards, claim_groups=claim_groups, ) - can_proceed = self._can_proceed_review( - payload, - missing_slot_keys=missing_slot_keys, - claim_groups=claim_groups, + association_choice_pending = self._is_review_association_choice_pending(payload) + can_proceed = ( + False + if association_choice_pending + else self._can_proceed_review( + payload, + missing_slot_keys=missing_slot_keys, + claim_groups=claim_groups, + ) ) confirmation_actions = self._build_review_confirmation_actions( payload, @@ -762,6 +803,7 @@ class UserAgentService: slot_cards=slot_cards, risk_briefs=risk_briefs, can_proceed=can_proceed, + document_cards=document_cards, ) return UserAgentReviewPayload( @@ -798,7 +840,10 @@ class UserAgentService: ocr_documents=ocr_documents, ) merchant_slot = self._build_merchant_slot(payload, ocr_documents=ocr_documents) - reason_slot = self._build_reason_slot(payload) + reason_slot = self._build_reason_slot( + payload, + claim_groups=claim_groups, + ) attachment_slot = self._build_attachment_slot(payload) required_keys = self._resolve_required_review_keys( payload, @@ -922,6 +967,8 @@ class UserAgentService: ), summary=str(item.get("summary") or item.get("text") or "").strip(), avg_score=float(item.get("avg_score") or 0.0), + preview_kind=str(item.get("preview_kind") or "").strip(), + preview_data_url=str(item.get("preview_data_url") or "").strip(), warnings=[str(warning) for warning in item.get("warnings", []) if str(warning).strip()], fields=[ UserAgentReviewDocumentField( @@ -950,14 +997,22 @@ class UserAgentService: { "document_indexes": [], "amount_total": 0.0, - "expense_type": group_code, - "scene_label": GROUP_SCENE_LABELS.get(group_code, "其他费用"), + "expense_type": str(card.suggested_expense_type or group_code).strip() or group_code, + "scene_label": GROUP_SCENE_LABELS.get( + str(card.suggested_expense_type or group_code).strip() or group_code, + GROUP_SCENE_LABELS.get(group_code, "其他费用"), + ), "reasons": [], }, ) bucket["document_indexes"].append(card.index) bucket["amount_total"] = float(bucket["amount_total"]) + self._extract_amount_from_card(card) bucket["reasons"].append(f"{card.filename} 识别为 {card.scene_label}") + current_expense_type = str(bucket["expense_type"] or "").strip() + current_card_type = str(card.suggested_expense_type or "").strip() + if current_expense_type and current_card_type and current_expense_type != current_card_type: + bucket["expense_type"] = group_code + bucket["scene_label"] = GROUP_SCENE_LABELS.get(group_code, "其他费用") if not groups: expense_type_code = self._collect_entity_values(payload).get("expense_type_code", "other") @@ -1080,6 +1135,40 @@ class UserAgentService: claim_groups: list[UserAgentReviewClaimGroup], draft_payload: UserAgentDraftPayload | None, ) -> list[UserAgentReviewAction]: + if self._is_review_association_choice_pending(payload): + claim_no = str(payload.tool_payload.get("association_candidate_claim_no") or "").strip() + link_label = f"关联到草稿 {claim_no}" if claim_no else "关联到现有草稿" + return [ + UserAgentReviewAction( + label="取消", + action_type="cancel_review", + description="放弃当前识别结果,并退出本次核对流程。", + emphasis="secondary", + ), + UserAgentReviewAction( + label="修改识别信息", + action_type="edit_review", + description="打开结构化模板,按已识别字段逐项修改。", + emphasis="secondary", + ), + UserAgentReviewAction( + label=link_label, + action_type="link_to_existing_draft", + description=( + f"把本次上传票据并入现有草稿 {claim_no}。" + if claim_no + else "把本次上传票据并入现有草稿。" + ), + emphasis="primary", + ), + UserAgentReviewAction( + label="单独建立报销单", + action_type="create_new_claim_from_documents", + description="基于当前上传的多张票据,新建一张独立的报销草稿。", + emphasis="secondary", + ), + ] + primary_action = UserAgentReviewAction( label="继续下一步" if can_proceed else "保存为草稿", action_type="next_step" if can_proceed else "save_draft", @@ -1171,6 +1260,22 @@ class UserAgentService: "后续您可以继续补充缺失项,或修改识别结果后再继续提交。" ) return "已按您当前确认的信息保存为草稿。后续您可以继续补充缺失项,或修改识别结果后再继续提交。" + if review_action == "link_to_existing_draft": + document_count = self._resolve_review_document_count(payload) + if draft_payload is not None and draft_payload.claim_no: + return ( + f"已将本次上传的 {document_count} 张票据关联到草稿 {draft_payload.claim_no}。" + "您可以继续补充识别字段,确认无误后再提交审批。" + ) + return "已将本次上传的票据关联到现有草稿。您可以继续补充识别字段,确认无误后再提交审批。" + if review_action == "create_new_claim_from_documents": + document_count = self._resolve_review_document_count(payload) + if draft_payload is not None and draft_payload.claim_no: + return ( + f"已按当前上传的 {document_count} 张票据新建报销草稿 {draft_payload.claim_no}。" + "您可以继续补充识别字段,确认无误后再提交审批。" + ) + return "已按当前上传票据新建报销草稿。您可以继续补充识别字段,确认无误后再提交审批。" if review_action == "next_step": if draft_payload is not None and draft_payload.status == "submitted": stage_text = draft_payload.approval_stage or "审批中" @@ -1195,7 +1300,21 @@ class UserAgentService: slot_cards: list[UserAgentReviewSlotCard], risk_briefs: list[UserAgentReviewRiskBrief], can_proceed: bool, + document_cards: list[UserAgentReviewDocumentCard], ) -> str: + if self._is_review_association_choice_pending(payload): + claim_no = str(payload.tool_payload.get("association_candidate_claim_no") or "").strip() + document_count = len(document_cards) or self._resolve_review_document_count(payload) + if claim_no: + return ( + f"已识别出本次上传的 {document_count} 张票据。" + f"系统检测到你已有草稿 {claim_no},请选择关联到该草稿,或单独建立一张新的报销单。" + ) + return ( + f"已识别出本次上传的 {document_count} 张票据。" + "系统检测到你已有可用草稿,请先选择关联到现有草稿,或单独建立一张新的报销单。" + ) + review_payload = UserAgentReviewPayload( intent_summary="", body_message="", @@ -1423,6 +1542,22 @@ class UserAgentService: return cleaned[:300] return "" + @staticmethod + def _looks_like_system_generated_reason_message(message: str) -> bool: + cleaned = str(message or "").strip() + if not cleaned: + return False + compact = re.sub(r"\s+", "", cleaned) + return compact.startswith(SYSTEM_GENERATED_REASON_PREFIXES) + + def _resolve_reason_source_text(self, payload: UserAgentRequest) -> str: + explicit_text = payload.context_json.get("user_input_text") + if isinstance(explicit_text, str): + return explicit_text.strip() + if self._looks_like_system_generated_reason_message(payload.message): + return "" + return str(payload.message or "").strip() + @classmethod def _resolve_reason_text(cls, message: str) -> str: reason = cls._extract_message_reason(message) @@ -1553,13 +1688,58 @@ class UserAgentService: documents = payload.context_json.get("ocr_documents") if not isinstance(documents, list): return [] + overrides = payload.context_json.get("review_document_form_values") + override_map: dict[tuple[int, str], dict[str, object]] = {} + if isinstance(overrides, list): + for item in overrides: + if not isinstance(item, dict): + continue + filename = str(item.get("filename") or "").strip() + index = int(item.get("index") or 0) + if not filename and index <= 0: + continue + override_map[(index, filename)] = item normalized: list[dict[str, object]] = [] - for item in documents[:8]: + for index, item in enumerate(documents[:8], start=1): if not isinstance(item, dict): continue - normalized.append(item) + normalized_item = dict(item) + override = override_map.get((index, str(normalized_item.get("filename") or "").strip())) + if override is None: + override = override_map.get((index, "")) + if override is not None: + summary = str(override.get("summary") or "").strip() + scene_label = str(override.get("scene_label") or "").strip() + fields = override.get("fields") + if summary: + normalized_item["summary"] = summary + if scene_label: + normalized_item["scene_label"] = scene_label + if isinstance(fields, list): + normalized_item["document_fields"] = [ + { + "key": str(field.get("key") or field.get("label") or "").strip(), + "label": str(field.get("label") or "").strip(), + "value": str(field.get("value") or "").strip(), + } + for field in fields + if isinstance(field, dict) + and str(field.get("label") or "").strip() + and str(field.get("value") or "").strip() + ] + normalized.append(normalized_item) return normalized + @staticmethod + def _is_review_association_choice_pending(payload: UserAgentRequest) -> bool: + return bool(payload.tool_payload.get("pending_association_decision")) + + def _resolve_review_document_count(self, payload: UserAgentRequest) -> int: + return max( + len(self._resolve_ocr_documents(payload)), + self._resolve_attachment_count(payload), + ) + @staticmethod def _resolve_conversation_history(payload: UserAgentRequest) -> list[dict[str, object]]: history = payload.context_json.get("conversation_history") @@ -1852,7 +2032,12 @@ class UserAgentService: ) return self._build_slot_value() - def _build_reason_slot(self, payload: UserAgentRequest) -> dict[str, str | float]: + def _build_reason_slot( + self, + payload: UserAgentRequest, + *, + claim_groups: list[UserAgentReviewClaimGroup], + ) -> dict[str, str | float]: review_form_values = self._resolve_review_form_values(payload) edited_value = str(review_form_values.get("reason") or "").strip() if edited_value: @@ -1865,7 +2050,7 @@ class UserAgentService: evidence="来源于用户修改后的结构化表单。", ) - reason_value = self._resolve_reason_text(payload.message) + reason_value = self._resolve_reason_text(self._resolve_reason_source_text(payload)) if reason_value: return self._build_slot_value( value=reason_value, @@ -1875,6 +2060,19 @@ class UserAgentService: confidence=0.76, evidence="系统从用户原始描述中提取了本次费用事由,建议继续核对。", ) + + inferred_reason = self._infer_reason_from_claim_groups( + claim_groups=claim_groups, + ) + if inferred_reason: + return self._build_slot_value( + value=inferred_reason, + raw_value=inferred_reason, + normalized_value=inferred_reason, + source="ocr", + confidence=0.68, + evidence="系统已根据票据识别场景补全通用事由,若需更具体说明可继续修改。", + ) return self._build_slot_value() def _build_amount_slot( @@ -2072,7 +2270,10 @@ class UserAgentService: if primary_expense_type: scene_codes.add(primary_expense_type) - compact_message = re.sub(r"\s+", "", payload.message) + for scene_code in scene_codes: + required.update(SCENE_REQUIRED_SLOT_KEYS.get(scene_code, set())) + + compact_message = re.sub(r"\s+", "", self._resolve_reason_source_text(payload) or payload.message) if "entertainment" in scene_codes or ( "客户" in compact_message and any(keyword in compact_message for keyword in ("招待", "吃饭", "用餐", "宴请", "请客")) ): @@ -2080,6 +2281,24 @@ class UserAgentService: return required + @staticmethod + def _infer_reason_from_claim_groups( + *, + claim_groups: list[UserAgentReviewClaimGroup], + ) -> str: + if len(claim_groups) == 1: + document_indexes = list(claim_groups[0].document_indexes or []) + if not document_indexes: + return "" + + expense_type = str(claim_groups[0].expense_type or "").strip() + group_code = str(claim_groups[0].group_code or "").strip() + if expense_type: + return INFERRED_REASON_LABELS.get(expense_type, "") or str(claim_groups[0].scene_label or "").strip() + if group_code: + return INFERRED_REASON_LABELS.get(group_code, "") or str(claim_groups[0].scene_label or "").strip() + return "" + @staticmethod def _resolve_review_missing_slot_keys( payload: UserAgentRequest, @@ -2087,6 +2306,7 @@ class UserAgentService: slot_cards: list[UserAgentReviewSlotCard], ) -> list[str]: required_keys = {item.key for item in slot_cards if item.required} + slot_map = {item.key: item for item in slot_cards} missing_keys = { item.key for item in slot_cards @@ -2094,7 +2314,15 @@ class UserAgentService: } for key in payload.ontology.missing_slots: normalized_key = str(key or "").strip() - if normalized_key and normalized_key in required_keys: + if ( + normalized_key + and normalized_key in required_keys + and ( + normalized_key not in slot_map + or slot_map[normalized_key].status == "missing" + or not str(slot_map[normalized_key].value).strip() + ) + ): missing_keys.add(normalized_key) ordered_keys: list[str] = [] @@ -2257,35 +2485,104 @@ class UserAgentService: def _extract_document_fields(self, item: dict[str, object]) -> dict[str, str]: raw_fields = item.get("document_fields") + normalized_fields: dict[str, str] = {} if isinstance(raw_fields, list): - normalized_fields: dict[str, str] = {} for field in raw_fields: if not isinstance(field, dict): continue + key = str(field.get("key") or "").strip() label = str(field.get("label") or "").strip() value = str(field.get("value") or "").strip() - if label and value: - normalized_fields[label] = value - if normalized_fields: - return normalized_fields + if not value: + continue + normalized_label = self._normalize_document_field_label(key=key, label=label) + display_label = normalized_label or label + normalized_value = self._normalize_document_field_value( + label=display_label, + value=value, + ) + if display_label and normalized_value: + normalized_fields.setdefault(display_label, normalized_value) text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip() - fields: dict[str, str] = {} - amount_match = AMOUNT_TEXT_PATTERN.search(text) - if amount_match: - fields["金额"] = f"{amount_match.group(1)}元" + amount_value = self._extract_amount_text_from_value(text) + if amount_value and "金额" not in normalized_fields: + normalized_fields["金额"] = amount_value date_match = DATE_TEXT_PATTERN.search(text) - if date_match: - fields["时间"] = date_match.group(1) + if date_match and "时间" not in normalized_fields: + normalized_fields["时间"] = date_match.group(1) - merchant = self._extract_document_merchant_name(item) - if merchant: - fields["商户/酒店"] = merchant - return fields + merchant = self._extract_document_merchant_name_from_text(text) + if merchant and "商户/酒店" not in normalized_fields: + normalized_fields["商户/酒店"] = merchant + return normalized_fields @staticmethod - def _extract_document_merchant_name(item: dict[str, object]) -> str: + def _normalize_document_field_label(*, key: str, label: str) -> str: + compact_key = str(key or "").strip().lower().replace("_", "") + compact_label = str(label or "").replace(" ", "") + if compact_key in { + "amount", + "totalamount", + "paymentamount", + "paidamount", + "actualamount", + } or any( + token in compact_label + for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额") + ): + return "金额" + if compact_key in {"date", "time", "issuedat", "invoicedate"} or any( + token in compact_label for token in ("日期", "时间", "开票日期", "发生时间") + ): + return "时间" + if compact_key in {"merchant", "merchantname", "sellername", "vendorname"} or any( + token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方") + ): + return "商户/酒店" + return label + + def _normalize_document_field_value(self, *, label: str, value: str) -> str: + normalized_label = str(label or "").strip() + raw_value = str(value or "").strip() + if not normalized_label or not raw_value: + return "" + if normalized_label == "金额": + return self._extract_amount_text_from_value(raw_value) or raw_value + if normalized_label == "时间": + match = DATE_TEXT_PATTERN.search(raw_value) + return match.group(1) if match else raw_value + return raw_value + + def _extract_amount_text_from_value(self, value: str) -> str: + raw_value = str(value or "").strip() + if not raw_value: + return "" + best_amount: Decimal | None = None + for pattern in (DOCUMENT_AMOUNT_PATTERN, DOCUMENT_CURRENCY_AMOUNT_PATTERN, AMOUNT_TEXT_PATTERN): + for match in pattern.finditer(raw_value): + try: + candidate = Decimal(str(match.group(1)).replace(",", ".")) + except (InvalidOperation, TypeError): + continue + if candidate <= Decimal("0.00"): + continue + if best_amount is None or candidate > best_amount: + best_amount = candidate + if best_amount is None: + return "" + return f"{best_amount.quantize(Decimal('0.01')):.2f}元" + + def _extract_document_merchant_name(self, item: dict[str, object]) -> str: + fields = self._extract_document_fields(item) + merchant = str(fields.get("商户/酒店") or "").strip() + if merchant: + return merchant text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip() + return self._extract_document_merchant_name_from_text(text) + + @staticmethod + def _extract_document_merchant_name_from_text(text: str) -> str: for keyword in ("酒店", "宾馆", "饭店", "酒楼", "餐厅", "航空", "铁路", "滴滴"): if keyword in text: return keyword @@ -2297,7 +2594,8 @@ class UserAgentService: if item.label != "金额": continue try: - return float(str(item.value).replace("元", "").strip()) + normalized_value = str(item.value).replace("元", "").replace("¥", "").replace("¥", "").strip() + return float(normalized_value) except ValueError: return 0.0 return 0.0 @@ -2315,7 +2613,7 @@ class UserAgentService: total = 0.0 for item in ocr_documents: fields = self._extract_document_fields(item) - amount_text = str(fields.get("金额") or "").replace("元", "").strip() + amount_text = str(fields.get("金额") or "").replace("元", "").replace("¥", "").replace("¥", "").strip() if not amount_text: continue try: diff --git a/server/tests/test_user_agent_service.py b/server/tests/test_user_agent_service.py index b398e60..473fd31 100644 --- a/server/tests/test_user_agent_service.py +++ b/server/tests/test_user_agent_service.py @@ -46,7 +46,7 @@ def test_user_agent_query_returns_readable_answer_and_actions() -> None: assert len(response.suggested_actions) >= 1 -def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> None: +def test_user_agent_returns_readable_query_answer_when_runtime_model_is_skipped(monkeypatch) -> None: session_factory = build_session_factory() with session_factory() as db: ontology = SemanticOntologyService(db).parse( @@ -56,11 +56,7 @@ def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> ) ) service = UserAgentService(db) - monkeypatch.setattr( - service, - "_generate_answer_with_model", - lambda *args, **kwargs: "这是模型回答", - ) + monkeypatch.setattr(service, "_generate_answer_with_model", lambda *args, **kwargs: "这是模型回答") response = service.respond( UserAgentRequest( @@ -72,7 +68,8 @@ def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> ) ) - assert response.answer == "这是模型回答" + assert "共 2 笔" in response.answer + assert "8800.00" in response.answer def test_user_agent_sanitizes_model_thinking_blocks() -> None: @@ -144,7 +141,7 @@ def test_user_agent_guides_implicit_expense_draft_request() -> None: assert response.review_payload is not None assert response.answer == response.review_payload.body_message - assert response.review_payload.intent_summary.startswith("我理解你这次想报销业务招待费。") + assert response.review_payload.intent_summary.startswith("识别到您希望报销一笔“业务招待费”费用。") assert response.review_payload.missing_slots == ["客户名称", "参与人员", "票据附件"] assert [item.action_type for item in response.review_payload.confirmation_actions] == [ "cancel_review", @@ -187,7 +184,102 @@ def test_user_agent_guides_narrative_with_day_before_yesterday() -> None: slot_map = {item.key: item for item in response.review_payload.slot_cards} assert slot_map["time_range"].raw_value == "前天" assert slot_map["time_range"].value == "2026-05-11" - assert "时间:2026-05-11" in response.review_payload.intent_summary + assert "时间为 2026-05-11" in response.review_payload.intent_summary + + +def test_user_agent_attachment_only_upload_uses_generic_scene_reason_without_fabrication() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。", + user_id="pytest", + context_json={ + "attachment_names": ["didi-trip.png"], + "attachment_count": 1, + "ocr_documents": [ + { + "filename": "didi-trip.png", + "summary": "滴滴出行 订单金额 32 元", + "text": "滴滴出行 订单金额 32 元", + "document_type": "taxi_receipt", + "scene_code": "transport", + } + ], + "user_input_text": "", + }, + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我上传了 1 份票据,请结合附件名称给出报销建议并尽量生成草稿。\n附件名称:didi-trip.png", + ontology=ontology, + context_json={ + "attachment_names": ["didi-trip.png"], + "attachment_count": 1, + "ocr_documents": [ + { + "filename": "didi-trip.png", + "summary": "滴滴出行 订单金额 32 元", + "text": "滴滴出行 订单金额 32 元", + "document_type": "taxi_receipt", + "scene_code": "transport", + } + ], + "user_input_text": "", + }, + tool_payload={"draft_only": True}, + ) + ) + + assert response.review_payload is not None + slot_map = {item.key: item for item in response.review_payload.slot_cards} + assert slot_map["reason"].value == "交通出行" + assert slot_map["reason"].status == "inferred" + + +def test_user_agent_transport_flow_infers_reason_and_does_not_require_location_or_merchant() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了交通票据,帮我生成报销草稿", + user_id="pytest", + ) + ) + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我上传了交通票据,帮我生成报销草稿", + ontology=ontology, + context_json={ + "attachment_names": ["didi-trip.png"], + "attachment_count": 1, + "ocr_documents": [ + { + "filename": "didi-trip.png", + "summary": "滴滴出行 支付金额 32 元", + "text": "滴滴出行 支付金额 32 元", + "document_type": "taxi_receipt", + "scene_code": "transport", + "scene_label": "交通票据", + } + ], + }, + tool_payload={"draft_only": True}, + ) + ) + + assert response.review_payload is not None + slot_map = {item.key: item for item in response.review_payload.slot_cards} + assert slot_map["reason"].value == "交通出行" + assert slot_map["reason"].status == "inferred" + assert "酒店/商户" not in response.review_payload.missing_slots + assert "地点" not in response.review_payload.missing_slots + assert "事由说明" not in response.review_payload.missing_slots def test_user_agent_risk_response_includes_rule_citations() -> None: @@ -347,7 +439,238 @@ def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> N "save_draft", ] assert any(item.scene_label == "业务招待费" for item in response.review_payload.document_cards) - assert f"时间:{yesterday}" in response.review_payload.intent_summary + assert f"时间为 {yesterday}" in response.review_payload.intent_summary slot_map = {item.key: item for item in response.review_payload.slot_cards} assert slot_map["time_range"].value == yesterday assert slot_map["time_range"].raw_value == "昨天" + + +def test_user_agent_sums_multi_document_amounts_from_synonym_fields() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了两张交通票据,帮我生成报销草稿", + user_id="pytest", + context_json={ + "attachment_names": ["滴滴行程单.png", "停车票.jpg"], + "attachment_count": 2, + "ocr_documents": [ + { + "filename": "滴滴行程单.png", + "summary": "滴滴出行电子行程单", + "text": "滴滴出行 订单金额 ¥32.50", + "avg_score": 0.94, + "document_fields": [ + {"key": "amount", "label": "支付金额", "value": "32.50"}, + ], + "warnings": [], + }, + { + "filename": "停车票.jpg", + "summary": "停车票", + "text": "停车费 合计 18 元", + "avg_score": 0.92, + "document_fields": [ + {"key": "total_amount", "label": "合计金额", "value": "18"}, + ], + "warnings": [], + }, + ], + }, + ) + ) + + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我上传了两张交通票据,帮我生成报销草稿", + ontology=ontology, + context_json={ + "attachment_names": ["滴滴行程单.png", "停车票.jpg"], + "attachment_count": 2, + "ocr_documents": [ + { + "filename": "滴滴行程单.png", + "summary": "滴滴出行电子行程单", + "text": "滴滴出行 订单金额 ¥32.50", + "avg_score": 0.94, + "document_fields": [ + {"key": "amount", "label": "支付金额", "value": "32.50"}, + ], + "warnings": [], + }, + { + "filename": "停车票.jpg", + "summary": "停车票", + "text": "停车费 合计 18 元", + "avg_score": 0.92, + "document_fields": [ + {"key": "total_amount", "label": "合计金额", "value": "18"}, + ], + "warnings": [], + }, + ], + }, + tool_payload={"draft_only": True}, + ) + ) + + assert response.review_payload is not None + slot_map = {item.key: item for item in response.review_payload.slot_cards} + assert slot_map["amount"].value == "50.50元" + document_field_labels = [ + field.label + for card in response.review_payload.document_cards + for field in card.fields + ] + assert "金额" in document_field_labels + + +def test_user_agent_prefers_larger_decimal_amount_from_ocr_text_candidates() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了打车票据,帮我生成报销草稿", + user_id="pytest", + context_json={ + "attachment_names": ["滴滴行程单.png"], + "attachment_count": 1, + "ocr_documents": [ + { + "filename": "滴滴行程单.png", + "summary": "滴滴出行电子行程单", + "text": "滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678", + "avg_score": 0.94, + "warnings": [], + }, + ], + }, + ) + ) + + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我上传了打车票据,帮我生成报销草稿", + ontology=ontology, + context_json={ + "attachment_names": ["滴滴行程单.png"], + "attachment_count": 1, + "ocr_documents": [ + { + "filename": "滴滴行程单.png", + "summary": "滴滴出行电子行程单", + "text": "滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678", + "avg_score": 0.94, + "warnings": [], + }, + ], + }, + tool_payload={"draft_only": True}, + ) + ) + + assert response.review_payload is not None + slot_map = {item.key: item for item in response.review_payload.slot_cards} + assert slot_map["amount"].value == "13.40元" + + +def test_user_agent_review_payload_keeps_document_preview_data() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了打车票据,帮我生成报销草稿", + user_id="pytest", + context_json={ + "attachment_names": ["滴滴行程单.png"], + "attachment_count": 1, + "ocr_documents": [ + { + "filename": "滴滴行程单.png", + "summary": "滴滴出行电子行程单", + "text": "滴滴出行 实付 13.4 元", + "avg_score": 0.94, + "preview_kind": "image", + "preview_data_url": "data:image/png;base64,ZmFrZQ==", + "warnings": [], + }, + ], + }, + ) + ) + + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我上传了打车票据,帮我生成报销草稿", + ontology=ontology, + context_json={ + "attachment_names": ["滴滴行程单.png"], + "attachment_count": 1, + "ocr_documents": [ + { + "filename": "滴滴行程单.png", + "summary": "滴滴出行电子行程单", + "text": "滴滴出行 实付 13.4 元", + "avg_score": 0.94, + "preview_kind": "image", + "preview_data_url": "data:image/png;base64,ZmFrZQ==", + "warnings": [], + }, + ], + }, + tool_payload={"draft_only": True}, + ) + ) + + assert response.review_payload is not None + assert response.review_payload.document_cards[0].preview_kind == "image" + assert response.review_payload.document_cards[0].preview_data_url.startswith("data:image/png;base64,") + + +def test_user_agent_prompts_existing_draft_association_choice_for_multi_documents() -> None: + session_factory = build_session_factory() + with session_factory() as db: + ontology = SemanticOntologyService(db).parse( + OntologyParseRequest( + query="我上传了两张票据,帮我生成报销草稿", + user_id="pytest", + ) + ) + + response = UserAgentService(db).respond( + UserAgentRequest( + run_id=ontology.run_id, + user_id="pytest", + message="我上传了两张票据,帮我生成报销草稿", + ontology=ontology, + context_json={ + "attachment_names": ["滴滴行程单.png", "餐饮发票.jpg"], + "attachment_count": 2, + "ocr_documents": [ + {"filename": "滴滴行程单.png", "summary": "滴滴出行 金额 32 元", "text": "滴滴出行 金额 32 元"}, + {"filename": "餐饮发票.jpg", "summary": "餐饮发票 金额 68 元", "text": "餐饮发票 金额 68 元"}, + ], + }, + tool_payload={ + "pending_association_decision": True, + "association_candidate_claim_no": "EXP-202605-008", + }, + ) + ) + + assert response.review_payload is not None + assert response.review_payload.can_proceed is False + assert [item.action_type for item in response.review_payload.confirmation_actions] == [ + "cancel_review", + "edit_review", + "link_to_existing_draft", + "create_new_claim_from_documents", + ] + assert "EXP-202605-008" in response.answer