Files
X-Financial/server/src/app/services/user_agent_documents.py
2026-05-22 10:42:31 +08:00

381 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import re
from decimal import Decimal, InvalidOperation
from typing import Mapping
from app.schemas.user_agent import UserAgentRequest, UserAgentReviewDocumentCard
# Chinese display label for each expense group code; used to summarise which
# expense kinds a batch of documents covers.
DEFAULT_GROUP_SCENE_LABELS = {
    "travel": "差旅费",
    "entertainment": "业务招待费",
    "meal": "伙食费",
    "transport": "交通费",
    "hotel": "住宿费",
    "office": "办公费",
    "training": "培训费",
    "communication": "通讯费",
    "welfare": "福利费",
    # Meeting invoices are classified into the "meeting" group; without a
    # label here they would be dropped from expense-type summaries.
    "meeting": "会议费",
    "other": "其他费用",
}
# A date such as 2026年5月22日, 2026/5/22 or 2026-05-22, optionally followed by
# an HH:MM time written with either an ASCII or a full-width colon.
DOCUMENT_DATE_TEXT_PATTERN = re.compile(
    r"(\d{4}[年/-]\d{1,2}[月/-]\d{1,2}日?(?:\s*[T ]?\s*(?:[01]?\d|2[0-3])[:：][0-5]\d)?)"
)
# A number followed by a currency unit (元, 块, 万元, 元整, and the homophone
# variants 员/圆/园). Group 1 is the number only. Longer units are listed
# before their prefixes so the whole unit is consumed by the match.
DOCUMENT_AMOUNT_TEXT_PATTERN = re.compile(
    r"(\d+(?:\.\d+)?)\s*(?:万元整|万元|万员|万圆|万园|万块|万|元整|块钱|块|元|员|圆|园)"
)
# A labelled amount (价税合计, 金额, 实付金额, 票价, 车费, ...) followed by any run of
# colons, whitespace, yen signs or the characters of 人民币, then the number.
# A comma before 1-2 trailing digits is accepted as a decimal separator.
DOCUMENT_AMOUNT_PATTERN = re.compile(
    r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
    r"[:：\s¥￥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
)
# A number introduced by a half-width (¥) or full-width (￥) yen sign.
DOCUMENT_CURRENCY_AMOUNT_PATTERN = re.compile(r"[¥￥]\s*([0-9]+(?:[.,][0-9]{1,2})?)")
class UserAgentDocumentService:
    """Centralize receipt/invoice classification and OCR field extraction.

    Kept in its own service so the main user-agent service does not keep
    growing. Documents are loosely typed OCR result dicts; the keys read here
    are ``document_type``, ``document_fields``, ``filename``, ``summary``,
    ``text``, ``scene_code``, ``scene_label`` and ``suggested_expense_type``.
    """

    # document_type -> (expense_type, group_code, scene_label) for upstream
    # types whose classification does not depend on context. ``meal_receipt``
    # is handled separately because it may be a meal or entertainment.
    _PROVIDED_TYPE_CLASSIFICATION: dict[str, tuple[str, str, str]] = {
        "flight_itinerary": ("travel", "travel", "差旅票据"),
        "train_ticket": ("travel", "travel", "差旅票据"),
        "hotel_invoice": ("hotel", "travel", "住宿票据"),
        "taxi_receipt": ("transport", "travel", "交通票据"),
        "parking_toll_receipt": ("transport", "travel", "交通票据"),
        "office_invoice": ("office", "office", "办公用品票据"),
        "meeting_invoice": ("meeting", "meeting", "会务票据"),
        "training_invoice": ("training", "training", "培训票据"),
    }

    # Currency markers removed from an amount string before numeric parsing.
    _AMOUNT_DECORATIONS: tuple[str, ...] = ("人民币", "元", "￥", "¥")

    # Units meaning "ten thousand yuan"; amounts written with them are scaled by
    # 10000. A bare "万" is ambiguous (e.g. 3万公里) and is left unscaled.
    _WAN_CURRENCY_UNITS: tuple[str, ...] = ("万元", "万员", "万圆", "万园", "万块")

    def __init__(self, *, group_scene_labels: Mapping[str, str] | None = None) -> None:
        """Create the service.

        Args:
            group_scene_labels: group_code -> display label used by
                :meth:`infer_expense_type_from_documents`. ``None`` or an empty
                mapping falls back to ``DEFAULT_GROUP_SCENE_LABELS``.
        """
        self._group_scene_labels = dict(group_scene_labels or DEFAULT_GROUP_SCENE_LABELS)

    @staticmethod
    def _classification(
        document_type: str,
        expense_type: str,
        group_code: str,
        scene_label: str,
    ) -> dict[str, str]:
        """Build the classification dict returned by :meth:`classify_document`."""
        return {
            "document_type": document_type,
            "expense_type": expense_type,
            "group_code": group_code,
            "scene_label": scene_label,
        }

    @staticmethod
    def _contains_any(text: str, keywords: tuple[str, ...]) -> bool:
        """Return True if any non-empty keyword occurs in *text*."""
        # Empty keywords are ignored: ``"" in text`` is always True and would
        # make every document match.
        return any(keyword and keyword in text for keyword in keywords)

    def classify_document(
        self,
        item: dict[str, object],
        *,
        expense_type_code: str = "",
        has_customer: bool = False,
    ) -> dict[str, str]:
        """Classify one OCR document into document type, expense type, group and scene.

        A known ``item["document_type"]`` (case-insensitive) wins. Otherwise the
        filename, summary and OCR text are scanned for keywords in priority
        order: travel ticket, hotel, transport, meal, office supplies. If
        nothing matches, the caller's expense type is used as the fallback.

        Args:
            item: OCR result dict; reads ``document_type``, ``filename``,
                ``summary`` and ``text``.
            expense_type_code: caller-selected expense type. It decides the
                meal/entertainment split and is the fallback expense type.
            has_customer: when True, meal receipts are classified as
                entertainment.

        Returns:
            A new dict with ``document_type``, ``expense_type``, ``group_code``
            and ``scene_label``.
        """
        provided_type = str(item.get("document_type") or "").strip().lower()
        normalized_expense_type = str(expense_type_code or "").strip().lower()
        # A meal with a customer (or an explicit entertainment expense) is entertainment.
        meal_group = (
            "entertainment"
            if normalized_expense_type == "entertainment" or has_customer
            else "meal"
        )

        if provided_type == "meal_receipt":
            return self._classification(provided_type, meal_group, meal_group, "餐饮票据")
        known = self._PROVIDED_TYPE_CLASSIFICATION.get(provided_type)
        if known is not None:
            return self._classification(provided_type, *known)

        text = " ".join(
            str(item.get(key) or "") for key in ("filename", "summary", "text")
        ).lower()
        compact = text.replace(" ", "")

        if self._contains_any(compact, ("机票", "航班", "火车", "高铁", "行程单")):
            return self._classification("travel_ticket", "travel", "travel", "差旅票据")
        if self._contains_any(compact, ("酒店", "住宿", "宾馆")):
            return self._classification("hotel_invoice", "hotel", "travel", "住宿票据")
        if self._contains_any(
            compact,
            ("打车", "出租车", "滴滴", "网约车", "乘车", "用车", "叫车", "车费", "车资", "的士", "过路费", "停车"),
        ):
            return self._classification("transport_receipt", "transport", "travel", "交通票据")
        # The original tuple held an empty-string keyword, which matched every
        # text; "餐厅" (restaurant) restores a real keyword in its place.
        if self._contains_any(compact, ("餐厅", "饭店", "酒楼", "酒家", "餐饮", "meal")):
            return self._classification("meal_receipt", meal_group, meal_group, "餐饮票据")
        if self._contains_any(
            compact,
            ("办公用品", "文具", "耗材", "办公耗材", "打印纸", "键盘", "鼠标", "白板", "墨盒", "硒鼓"),
        ):
            return self._classification("other", "office", "office", "办公用品票据")

        fallback_type = normalized_expense_type or "other"
        return self._classification(
            "other", fallback_type, self.normalize_group_code(fallback_type), "其他票据"
        )

    @staticmethod
    def normalize_group_code(expense_type_code: str) -> str:
        """Map an expense type code to its group code.

        travel/hotel/transport collapse to ``"travel"``; a fixed set of codes
        passes through unchanged; anything else becomes ``"other"``.
        """
        # NOTE(review): "meeting" is not in the pass-through set, although
        # classify_document emits group "meeting" for meeting invoices.
        if expense_type_code in {"travel", "hotel", "transport"}:
            return "travel"
        if expense_type_code in {"entertainment", "meal", "office", "training", "communication", "welfare"}:
            return expense_type_code
        return "other"

    def extract_document_fields(self, item: dict[str, object]) -> dict[str, str]:
        """Collect normalized display fields (amount, time, merchant) for one document.

        Structured ``item["document_fields"]`` entries (dicts with ``key``,
        ``label`` and ``value``) are normalized first; the first value per
        label wins. Missing amount, time and (hotel documents only) merchant
        values are then filled from ``summary`` + ``text``. Text-derived values
        never replace structured ones.

        Returns:
            display label -> normalized value.
        """
        document_type = str(item.get("document_type") or "").strip().lower()
        is_hotel = self.is_hotel_document_item(item)
        normalized_fields: dict[str, str] = {}

        raw_fields = item.get("document_fields")
        if isinstance(raw_fields, list):
            for field in raw_fields:
                if not isinstance(field, dict):
                    continue
                key = str(field.get("key") or "").strip()
                label = str(field.get("label") or "").strip()
                value = str(field.get("value") or "").strip()
                if not value:
                    continue
                display_label = self.resolve_document_time_display_label(
                    document_type=document_type,
                    key=key,
                    label=label,
                    normalized_label=self.normalize_document_field_label(key=key, label=label) or label,
                )
                # Merchant names are only surfaced for hotel documents.
                if display_label == "商户/酒店" and not is_hotel:
                    continue
                normalized_value = self.normalize_document_field_value(label=display_label, value=value)
                if display_label and normalized_value:
                    normalized_fields.setdefault(display_label, normalized_value)

        text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()
        amount_value = self.extract_amount_text_from_value(text)
        if amount_value:
            normalized_fields.setdefault("金额", amount_value)
        date_match = DOCUMENT_DATE_TEXT_PATTERN.search(text)
        if date_match:
            time_label = self.resolve_document_time_display_label(
                document_type=document_type,
                key="date",
                label="日期",
                normalized_label="时间",
            )
            # Check the resolved (possibly type-specific) label, not "时间", so a
            # structured 列车出发时间/乘车时间 value is not overwritten.
            normalized_fields.setdefault(time_label, date_match.group(1))
        if is_hotel:
            merchant = self.extract_document_merchant_name_from_text(text)
            if merchant:
                normalized_fields.setdefault("商户/酒店", merchant)
        return normalized_fields

    @staticmethod
    def resolve_document_time_display_label(
        *,
        document_type: str,
        key: str,
        label: str,
        normalized_label: str,
    ) -> str:
        """Replace the generic "时间" label with a document-type-specific one.

        Only applies when *normalized_label* is "时间", the document type has a
        specific label, and the key or label looks like a date/time field.
        Otherwise *normalized_label* is returned unchanged.
        """
        if normalized_label != "时间":
            return normalized_label
        label_by_type = {
            "train_ticket": "列车出发时间",
            "flight_itinerary": "起飞日期",
            "taxi_receipt": "乘车时间",
            "transport_receipt": "乘车时间",
            "parking_toll_receipt": "通行日期",
        }
        normalized_type = str(document_type or "").strip().lower()
        if normalized_type not in label_by_type:
            return normalized_label
        compact_key = str(key or "").strip().lower().replace("_", "")
        compact_label = str(label or "").replace(" ", "")
        if compact_key in {"date", "time", "issuedat", "issuedate", "invoicedate"}:
            return label_by_type[normalized_type]
        if any(token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")):
            return label_by_type[normalized_type]
        return normalized_label

    @staticmethod
    def normalize_document_field_label(*, key: str, label: str) -> str:
        """Map a raw OCR field key/label to "金额", "时间" or "商户/酒店".

        Amount is checked first, then time, then merchant. Matching is on the
        underscore-free lower-cased key or on substrings of the space-free
        label. Unrecognized fields keep their original *label*.
        """
        compact_key = str(key or "").strip().lower().replace("_", "")
        compact_label = str(label or "").replace(" ", "")
        if compact_key in {
            "amount",
            "totalamount",
            "paymentamount",
            "paidamount",
            "actualamount",
        } or any(
            token in compact_label
            for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
        ):
            return "金额"
        # "issuedate" is included to match resolve_document_time_display_label.
        if compact_key in {"date", "time", "issuedat", "issuedate", "invoicedate"} or any(
            token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")
        ):
            return "时间"
        if compact_key in {"merchant", "merchantname", "sellername", "vendorname"} or any(
            token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方")
        ):
            return "商户/酒店"
        return label

    def normalize_document_field_value(self, *, label: str, value: str) -> str:
        """Normalize a field value according to its display label.

        Amounts become "0.00"-formatted text when parseable. Date/time labels
        are reduced to the first date (and optional time) found. Other values
        are returned stripped. Returns "" when label or value is empty.
        """
        normalized_label = str(label or "").strip()
        raw_value = str(value or "").strip()
        if not normalized_label or not raw_value:
            return ""
        if normalized_label == "金额":
            return self.extract_amount_text_from_value(raw_value) or raw_value
        if normalized_label in {"时间", "出发日期", "列车出发时间", "起飞日期", "乘车时间", "通行日期"}:
            match = DOCUMENT_DATE_TEXT_PATTERN.search(raw_value)
            return match.group(1) if match else raw_value
        return raw_value

    def extract_amount_text_from_value(self, value: str) -> str:
        """Return the largest positive amount in *value*, formatted as "0.00".

        Three patterns are tried: a labelled amount (金额, 价税合计, ...), a
        yen-sign amount, and a number followed by a currency unit (元, 块,
        万元, ...). A comma before 1-2 trailing digits is read as a decimal
        separator. Amounts written in 万元-style units are scaled by 10000.
        Returns "" when no positive amount is found.
        """
        raw_value = str(value or "").strip()
        if not raw_value:
            return ""
        best_amount: Decimal | None = None
        for pattern in (DOCUMENT_AMOUNT_PATTERN, DOCUMENT_CURRENCY_AMOUNT_PATTERN, DOCUMENT_AMOUNT_TEXT_PATTERN):
            for match in pattern.finditer(raw_value):
                try:
                    candidate = Decimal(str(match.group(1)).replace(",", "."))
                except (InvalidOperation, TypeError):
                    continue
                # Any unit text trails the captured number inside the match.
                unit = raw_value[match.end(1):match.end(0)].strip()
                if unit.startswith(self._WAN_CURRENCY_UNITS):
                    candidate *= 10000
                if candidate <= Decimal("0.00"):
                    continue
                if best_amount is None or candidate > best_amount:
                    best_amount = candidate
        if best_amount is None:
            return ""
        return f"{best_amount.quantize(Decimal('0.01')):.2f}"

    def extract_document_merchant_name(self, item: dict[str, object]) -> str:
        """Return the merchant/hotel name recorded for *item*, or "".

        ``extract_document_fields`` already keeps "商户/酒店" only for hotel
        documents and already falls back to the OCR text for them, so its
        result is used directly.
        """
        return str(self.extract_document_fields(item).get("商户/酒店") or "").strip()

    @staticmethod
    def is_hotel_document_item(item: dict[str, object]) -> bool:
        """Return True if the document type, scene or suggested expense type indicates lodging."""
        document_type = str(item.get("document_type") or "").strip().lower()
        scene_code = str(item.get("scene_code") or "").strip().lower()
        scene_label = str(item.get("scene_label") or "").strip()
        suggested_expense_type = str(item.get("suggested_expense_type") or "").strip().lower()
        return (
            document_type == "hotel_invoice"
            or scene_code == "hotel"
            or suggested_expense_type == "hotel"
            or "住宿" in scene_label
            or "酒店" in scene_label
        )

    @staticmethod
    def extract_document_merchant_name_from_text(text: str) -> str:
        """Return the first merchant-category keyword found in *text*, or "".

        Note this returns the keyword itself (e.g. "酒店"), not a full merchant name.
        """
        for keyword in ("酒店", "宾馆", "饭店", "酒楼", "餐厅", "航空", "铁路", "滴滴"):
            if keyword in text:
                return keyword
        return ""

    @classmethod
    def _strip_amount_text(cls, value: object) -> str:
        """Return *value* as text with currency markers and surrounding whitespace removed."""
        text = "" if value is None else str(value)
        for decoration in cls._AMOUNT_DECORATIONS:
            text = text.replace(decoration, "")
        return text.strip()

    @staticmethod
    def extract_amount_from_card(card: UserAgentReviewDocumentCard) -> float:
        """Return the first "金额" field of *card* as a float.

        Returns 0.0 if there is no such field or its value cannot be parsed.
        """
        for field in card.fields:
            if field.label != "金额":
                continue
            try:
                return float(UserAgentDocumentService._strip_amount_text(field.value))
            except ValueError:
                return 0.0
        return 0.0

    @staticmethod
    def resolve_amount_value(payload: UserAgentRequest) -> float:
        """Return the first non-threshold "amount" entity's normalized value as a float.

        Returns 0.0 when there is no such entity, or when the first one has a
        missing or unparsable value.
        """
        for entity in payload.ontology.entities:
            if entity.type == "amount" and entity.role != "threshold":
                try:
                    return float(entity.normalized_value)
                except (TypeError, ValueError):
                    # TypeError covers a missing (None) normalized_value.
                    return 0.0
        return 0.0

    def sum_ocr_amounts(self, ocr_documents: list[dict[str, object]]) -> float:
        """Sum the "金额" field of every document, skipping missing or unparsable amounts.

        The sum is accumulated in Decimal to avoid binary-float drift on money
        values, and is returned as a float.
        """
        total = Decimal("0")
        for item in ocr_documents:
            amount_text = self._strip_amount_text(self.extract_document_fields(item).get("金额") or "")
            if not amount_text:
                continue
            try:
                amount = Decimal(amount_text)
            except InvalidOperation:
                continue
            if not amount.is_finite():
                continue
            total += amount
        return float(total)

    def infer_expense_type_from_documents(
        self,
        ocr_documents: list[dict[str, object]],
        *,
        expense_type_code: str = "",
        has_customer: bool = False,
    ) -> str:
        """Summarize the expense groups of *ocr_documents* as up to three joined labels.

        Each document is classified with :meth:`classify_document`, and its
        group code is mapped through the configured group labels. Groups
        without a label are skipped. Duplicates keep their first-seen order,
        and at most three labels are joined with " + ".
        """
        labels: list[str] = []
        for item in ocr_documents:
            group_code = self.classify_document(
                item,
                expense_type_code=expense_type_code,
                has_customer=has_customer,
            )["group_code"]
            label = self._group_scene_labels.get(group_code, "")
            if label and label not in labels:
                labels.append(label)
        return " + ".join(labels[:3])