from __future__ import annotations

import re
from decimal import Decimal, InvalidOperation
from typing import Mapping

from app.schemas.user_agent import UserAgentRequest, UserAgentReviewDocumentCard

# Display label for each expense group code (used when summarising documents).
DEFAULT_GROUP_SCENE_LABELS = {
    "travel": "差旅费",
    "entertainment": "业务招待费",
    "meal": "业务招待费",
    "transport": "交通费",
    "hotel": "住宿费",
    "office": "办公用品费",
    "training": "培训费",
    "communication": "通讯费",
    "welfare": "福利费",
    "other": "其他费用",
}

# Display label for each document type code.
DOCUMENT_SCENE_LABELS = {
    "flight_itinerary": "机票/航班行程单",
    "train_ticket": "火车/高铁票",
    "ship_ticket": "轮船票",
    "travel_ticket": "交通出行票据",
    "hotel_invoice": "酒店住宿票据",
    "taxi_receipt": "出租车/网约车票据",
    "transport_receipt": "乘车票据",
    "parking_toll_receipt": "停车/通行费票据",
    "meal_receipt": "餐饮发票",
    "office_invoice": "文具/办公用品发票",
    "meeting_invoice": "会议/会务票据",
    "training_invoice": "培训票据",
    "other": "其他票据",
}

# A date (year/month/day with 年月日, "/" or "-" separators) optionally
# followed by an HH:MM time.
DOCUMENT_DATE_TEXT_PATTERN = re.compile(
    r"(\d{4}[年/-]\d{1,2}[月/-]\d{1,2}日?(?:\s*[T ]?\s*(?:[01]?\d|2[0-3])[::][0-5]\d)?)"
)
# A number followed by a currency unit; several alternatives contain "万"
# (ten thousand), which the caller must scale accordingly.
DOCUMENT_AMOUNT_TEXT_PATTERN = re.compile(
    r"(\d+(?:\.\d+)?)\s*(?:万元|万员|万圆|万园|万块|万元整|元整|块钱|块|元|员|圆|园|万)"
)
# An amount keyword followed by optional separators/currency marks and a number.
DOCUMENT_AMOUNT_PATTERN = re.compile(
    r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
    r"[::\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
)
# A yen/yuan sign followed by a number.
DOCUMENT_CURRENCY_AMOUNT_PATTERN = re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)")

# Multiplier applied when the matched unit contains "万".
_TEN_THOUSAND = Decimal("10000")

# document type -> (expense_type, group_code) for every type that
# classify_document can emit from an explicit type or a keyword hit.
_DOCUMENT_TYPE_CODES: dict[str, tuple[str, str]] = {
    "flight_itinerary": ("travel", "travel"),
    "train_ticket": ("travel", "travel"),
    "ship_ticket": ("travel", "travel"),
    "hotel_invoice": ("hotel", "travel"),
    "taxi_receipt": ("transport", "travel"),
    "transport_receipt": ("transport", "travel"),
    "parking_toll_receipt": ("transport", "travel"),
    "meal_receipt": ("meal", "meal"),
    "office_invoice": ("office", "office"),
    "meeting_invoice": ("meeting", "meeting"),
    "training_invoice": ("training", "training"),
}

# Ordered keyword rules: the first rule with any keyword present in the
# document text decides the document type, so the order is significant.
_DOCUMENT_KEYWORD_RULES: tuple[tuple[tuple[str, ...], str], ...] = (
    (("火车", "高铁", "动车", "铁路", "车次"), "train_ticket"),
    (("过路费", "停车", "通行费", "收费站"), "parking_toll_receipt"),
    (("打车", "出租车", "滴滴", "网约车", "叫车", "车费", "车资", "的士"), "taxi_receipt"),
    (("乘车", "用车"), "transport_receipt"),
    (("机票", "航班", "登机", "航空", "客票"), "flight_itinerary"),
    (("轮船", "船票", "客轮", "渡轮", "航运"), "ship_ticket"),
    (("酒店", "住宿", "宾馆"), "hotel_invoice"),
    (("餐", "饭店", "酒楼", "酒家", "餐饮", "meal"), "meal_receipt"),
    (
        ("办公用品", "文具", "耗材", "办公耗材", "打印纸", "键盘", "鼠标", "白板", "墨盒", "硒鼓"),
        "office_invoice",
    ),
)


class UserAgentDocumentService:
    """Centralize document classification and OCR field extraction.

    Keeping this logic here stops the main service from growing further.
    """

    def __init__(self, *, group_scene_labels: Mapping[str, str] | None = None) -> None:
        """Store a private copy of the group-code -> label mapping.

        Falls back to ``DEFAULT_GROUP_SCENE_LABELS`` when no mapping (or an
        empty one) is supplied.
        """
        self._group_scene_labels = dict(group_scene_labels or DEFAULT_GROUP_SCENE_LABELS)

    @staticmethod
    def _build_classification(document_type: str) -> dict[str, str]:
        """Build the classification dict for a type in ``_DOCUMENT_TYPE_CODES``."""
        expense_type, group_code = _DOCUMENT_TYPE_CODES[document_type]
        return {
            "document_type": document_type,
            "expense_type": expense_type,
            "group_code": group_code,
            "scene_label": DOCUMENT_SCENE_LABELS[document_type],
        }

    def classify_document(
        self,
        item: dict[str, object],
        *,
        expense_type_code: str = "",
        has_customer: bool = False,
    ) -> dict[str, str]:
        """Classify one document into type, expense type, group and label.

        An explicit ``item["document_type"]`` wins when it is one of the
        known types. Otherwise the lower-cased, space-stripped concatenation
        of ``filename``, ``summary`` and ``text`` is scanned against the
        ordered keyword rules. If nothing matches, the document is "other"
        and ``expense_type_code`` (when given) supplies the expense type.

        ``has_customer`` is accepted for interface compatibility but is not
        read by this method.

        Returns a dict with the keys ``document_type``, ``expense_type``,
        ``group_code`` and ``scene_label``.
        """
        provided_type = str(item.get("document_type") or "").strip().lower()
        if provided_type in _DOCUMENT_TYPE_CODES:
            return self._build_classification(provided_type)

        text = " ".join(
            [
                str(item.get("filename") or ""),
                str(item.get("summary") or ""),
                str(item.get("text") or ""),
            ]
        ).lower()
        compact = text.replace(" ", "")
        for keywords, document_type in _DOCUMENT_KEYWORD_RULES:
            if any(keyword in compact for keyword in keywords):
                return self._build_classification(document_type)

        fallback_expense_type = str(expense_type_code or "").strip().lower() or "other"
        return {
            "document_type": "other",
            "expense_type": fallback_expense_type,
            "group_code": self.normalize_group_code(fallback_expense_type),
            "scene_label": "其他票据",
        }

    @staticmethod
    def normalize_group_code(expense_type_code: str) -> str:
        """Map an expense type code onto its group code.

        Travel, hotel and transport collapse to "travel"; a fixed set of
        codes map to themselves; everything else becomes "other".
        """
        if expense_type_code in {"travel", "hotel", "transport"}:
            return "travel"
        if expense_type_code in {
            "entertainment",
            "meal",
            "office",
            "training",
            "communication",
            "welfare",
        }:
            return expense_type_code
        return "other"

    def extract_document_fields(self, item: dict[str, object]) -> dict[str, str]:
        """Return normalized display fields (label -> value) for a document.

        Structured entries in ``item["document_fields"]`` are normalized
        first and the first value seen for a label wins. Amount, date and
        (for hotel documents only) merchant are then filled in from the
        ``summary``/``text`` free text, but only for labels that the
        structured fields did not already provide.
        """
        raw_fields = item.get("document_fields")
        normalized_fields: dict[str, str] = {}
        document_type = str(item.get("document_type") or "").strip().lower()
        is_hotel = self.is_hotel_document_item(item)

        if isinstance(raw_fields, list):
            for field in raw_fields:
                if not isinstance(field, dict):
                    continue
                key = str(field.get("key") or "").strip()
                label = str(field.get("label") or "").strip()
                value = str(field.get("value") or "").strip()
                if not value:
                    continue
                normalized_label = self.normalize_document_field_label(key=key, label=label)
                display_label = self.resolve_document_time_display_label(
                    document_type=document_type,
                    key=key,
                    label=label,
                    normalized_label=normalized_label or label,
                )
                # Merchant fields are only kept for hotel documents.
                if display_label == "商户/酒店" and not is_hotel:
                    continue
                normalized_value = self.normalize_document_field_value(
                    label=display_label,
                    value=value,
                )
                if display_label and normalized_value:
                    normalized_fields.setdefault(display_label, normalized_value)

        text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()

        amount_value = self.extract_amount_text_from_value(text)
        if amount_value and "金额" not in normalized_fields:
            normalized_fields["金额"] = amount_value

        date_match = DOCUMENT_DATE_TEXT_PATTERN.search(text)
        if date_match:
            time_label = self.resolve_document_time_display_label(
                document_type=document_type,
                key="date",
                label="日期",
                normalized_label="时间",
            )
            # Structured time fields are stored under the type-specific label
            # (e.g. "列车出发时间"), so that label must be checked too;
            # otherwise a date scraped from free text would overwrite the
            # structured value.
            if "时间" not in normalized_fields and time_label not in normalized_fields:
                normalized_fields[time_label] = date_match.group(1)

        merchant = self.extract_document_merchant_name_from_text(text) if is_hotel else ""
        if merchant and "商户/酒店" not in normalized_fields:
            normalized_fields["商户/酒店"] = merchant
        return normalized_fields

    @staticmethod
    def resolve_document_time_display_label(
        *,
        document_type: str,
        key: str,
        label: str,
        normalized_label: str,
    ) -> str:
        """Replace the generic "时间" label with a document-type-specific one.

        Any ``normalized_label`` other than "时间" is returned unchanged, as
        is "时间" for document types without a specific label. The specific
        label is used when the field key is a known date/time key or the
        original label contains a date/time token.
        """
        if normalized_label != "时间":
            return normalized_label
        label_by_type = {
            "train_ticket": "列车出发时间",
            "flight_itinerary": "起飞日期",
            "taxi_receipt": "乘车时间",
            "transport_receipt": "乘车时间",
            "parking_toll_receipt": "通行日期",
        }
        normalized_type = str(document_type or "").strip().lower()
        if normalized_type not in label_by_type:
            return normalized_label
        compact_key = str(key or "").strip().lower().replace("_", "")
        compact_label = str(label or "").replace(" ", "")
        if compact_key in {"date", "time", "issuedat", "issuedate", "invoicedate"}:
            return label_by_type[normalized_type]
        if any(token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")):
            return label_by_type[normalized_type]
        return normalized_label

    @staticmethod
    def normalize_document_field_label(*, key: str, label: str) -> str:
        """Map a raw field key/label onto a canonical display label.

        Returns "金额", "时间" or "商户/酒店" when the key (lower-cased,
        underscores removed) or the label (spaces removed) matches the
        respective group, checked in that order; otherwise returns ``label``
        unchanged.
        """
        compact_key = str(key or "").strip().lower().replace("_", "")
        compact_label = str(label or "").replace(" ", "")
        if compact_key in {
            "amount",
            "totalamount",
            "paymentamount",
            "paidamount",
            "actualamount",
        } or any(
            token in compact_label
            for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
        ):
            return "金额"
        if compact_key in {"date", "time", "issuedat", "invoicedate"} or any(
            token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")
        ):
            return "时间"
        if compact_key in {"merchant", "merchantname", "sellername", "vendorname"} or any(
            token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方")
        ):
            return "商户/酒店"
        return label

    def normalize_document_field_value(self, *, label: str, value: str) -> str:
        """Normalize a field value according to its display label.

        Amounts are rewritten as "<n.nn>元" and time-like labels are reduced
        to the first date found; in both cases the stripped raw value is
        returned when nothing can be extracted. Returns "" when the label or
        the value is empty.
        """
        normalized_label = str(label or "").strip()
        raw_value = str(value or "").strip()
        if not normalized_label or not raw_value:
            return ""
        if normalized_label == "金额":
            return self.extract_amount_text_from_value(raw_value) or raw_value
        if normalized_label in {"时间", "出发日期", "列车出发时间", "起飞日期", "乘车时间", "通行日期"}:
            match = DOCUMENT_DATE_TEXT_PATTERN.search(raw_value)
            return match.group(1) if match else raw_value
        return raw_value

    def extract_amount_text_from_value(self, value: str) -> str:
        """Return the largest positive amount found in *value* as "<n.nn>元".

        All three amount patterns are scanned and the maximum candidate wins.
        Numbers whose unit contains "万" are multiplied by 10000. Returns ""
        when no positive amount is found.
        """
        raw_value = str(value or "").strip()
        if not raw_value:
            return ""
        best_amount: Decimal | None = None
        for pattern in (
            DOCUMENT_AMOUNT_PATTERN,
            DOCUMENT_CURRENCY_AMOUNT_PATTERN,
            DOCUMENT_AMOUNT_TEXT_PATTERN,
        ):
            for match in pattern.finditer(raw_value):
                try:
                    # NOTE: a comma is read as a decimal separator; thousands
                    # separators such as "1,234.56" are not supported here.
                    candidate = Decimal(match.group(1).replace(",", "."))
                except InvalidOperation:
                    continue
                # In the unit pattern the only place "万" can occur is the
                # unit itself, so its presence means "times ten thousand".
                if pattern is DOCUMENT_AMOUNT_TEXT_PATTERN and "万" in match.group(0):
                    candidate *= _TEN_THOUSAND
                if candidate <= Decimal("0.00"):
                    continue
                if best_amount is None or candidate > best_amount:
                    best_amount = candidate
        if best_amount is None:
            return ""
        return f"{best_amount.quantize(Decimal('0.01')):.2f}元"

    def extract_document_merchant_name(self, item: dict[str, object]) -> str:
        """Return the merchant for a document, or "" when none applies.

        Uses the "商户/酒店" entry of the extracted fields when present;
        otherwise, for hotel documents only, falls back to a keyword lookup
        in the ``summary``/``text`` free text.
        """
        fields = self.extract_document_fields(item)
        merchant = str(fields.get("商户/酒店") or "").strip()
        if merchant:
            return merchant
        if not self.is_hotel_document_item(item):
            return ""
        text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()
        return self.extract_document_merchant_name_from_text(text)

    @staticmethod
    def is_hotel_document_item(item: dict[str, object]) -> bool:
        """Tell whether any of the item's type/scene hints mark it as a hotel document."""
        document_type = str(item.get("document_type") or "").strip().lower()
        scene_code = str(item.get("scene_code") or "").strip().lower()
        scene_label = str(item.get("scene_label") or "").strip()
        suggested_expense_type = str(item.get("suggested_expense_type") or "").strip().lower()
        return (
            document_type == "hotel_invoice"
            or scene_code == "hotel"
            or suggested_expense_type == "hotel"
            or "住宿" in scene_label
            or "酒店" in scene_label
        )

    @staticmethod
    def extract_document_merchant_name_from_text(text: str) -> str:
        """Return the first merchant keyword contained in *text*, or "".

        The returned value is the matched keyword itself, not a full
        merchant name.
        """
        for keyword in ("酒店", "宾馆", "饭店", "酒楼", "餐厅", "航空", "铁路", "滴滴"):
            if keyword in text:
                return keyword
        return ""

    @staticmethod
    def extract_amount_from_card(card: UserAgentReviewDocumentCard) -> float:
        """Return the card's first "金额" field as a float.

        Currency marks are stripped before parsing. Returns 0.0 when the
        card has no such field or the first one cannot be parsed.
        """
        for item in card.fields:
            if item.label != "金额":
                continue
            try:
                normalized_value = (
                    str(item.value).replace("元", "").replace("¥", "").replace("¥", "").strip()
                )
                return float(normalized_value)
            except ValueError:
                return 0.0
        return 0.0

    @staticmethod
    def resolve_amount_value(payload: UserAgentRequest) -> float:
        """Return the first non-threshold amount entity of the payload as a float.

        Returns 0.0 when there is no such entity or its ``normalized_value``
        cannot be converted.
        """
        for item in payload.ontology.entities:
            if item.type == "amount" and item.role != "threshold":
                try:
                    return float(item.normalized_value)
                except (TypeError, ValueError):
                    # TypeError covers a missing (None) normalized value.
                    return 0.0
        return 0.0

    def sum_ocr_amounts(self, ocr_documents: list[dict[str, object]]) -> float:
        """Sum the "金额" field of every document and return it as a float.

        Amounts are accumulated as ``Decimal`` so that binary floating point
        error does not build up (0.10 + 0.20 yields exactly 0.3). Documents
        whose amount is missing, unparsable or non-finite are skipped.
        """
        total = Decimal("0")
        for item in ocr_documents:
            fields = self.extract_document_fields(item)
            amount_text = (
                str(fields.get("金额") or "").replace("元", "").replace("¥", "").replace("¥", "").strip()
            )
            if not amount_text:
                continue
            try:
                amount = Decimal(amount_text)
            except (InvalidOperation, ValueError):
                continue
            if not amount.is_finite():
                continue
            total += amount
        return float(total)

    def infer_expense_type_from_documents(
        self,
        ocr_documents: list[dict[str, object]],
        *,
        expense_type_code: str = "",
        has_customer: bool = False,
    ) -> str:
        """Summarize the expense groups present in the documents.

        Each document is classified and its group code is translated through
        the configured label mapping. The first three distinct labels, in
        order of first appearance, are joined with " + ". Returns "" when no
        document yields a label.
        """
        labels: list[str] = []
        for item in ocr_documents:
            classified = self.classify_document(
                item,
                expense_type_code=expense_type_code,
                has_customer=has_customer,
            )
            label = self._group_scene_labels.get(classified["group_code"], "")
            if label and label not in labels:
                labels.append(label)
        return " + ".join(labels[:3])