Files
X-Financial/server/src/app/services/document_intelligence.py

642 lines
24 KiB
Python
Raw Normal View History

from __future__ import annotations
import json
import re
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
@dataclass(frozen=True, slots=True)
class DocumentField:
    """One extracted key/value pair from a document (e.g. amount or date)."""

    # Canonical field identifier, e.g. "amount", "date", "merchant_name".
    key: str
    # Display label shown to users, e.g. "金额" for "amount".
    label: str
    # Extracted value, always kept as text.
    value: str
@dataclass(frozen=True, slots=True)
class DocumentInsight:
    """Classification result plus extracted fields for a single document."""

    # Machine identifier of the document type (a DOCUMENT_RULES type or "other").
    document_type: str
    # Display label for ``document_type``.
    document_type_label: str
    # Scene category code (e.g. "travel", "hotel", "other") and its display label.
    scene_code: str
    scene_label: str
    # Expense category (e.g. "travel", "meal", "other").
    expense_type: str
    # Key fields extracted from the document text.
    fields: tuple[DocumentField, ...] = ()
    # "rule" for keyword classification; otherwise the source label supplied
    # by the model-review hook.
    classification_source: str = "rule"
    # Confidence in the classification, in [0, 1].
    classification_confidence: float = 0.0
    # Keywords or model-supplied snippets supporting the classification.
    evidence: tuple[str, ...] = ()
    # Human-readable review warnings (e.g. "fields corrected by the model").
    warnings: tuple[str, ...] = ()
@dataclass(frozen=True, slots=True)
class DocumentRule:
    """Keyword rule mapping matching document text to a document type."""

    # Classification produced when this rule wins.
    document_type: str
    document_type_label: str
    scene_code: str
    scene_label: str
    expense_type: str
    # Substrings searched for in the whitespace-free, lower-cased document text.
    keywords: tuple[str, ...]
    # Constant added to this rule's score whenever at least one keyword matches.
    score_bias: float = 0.0
@dataclass(frozen=True, slots=True)
class RuleMatch:
    """Outcome of scoring the document text against DOCUMENT_RULES."""

    # Winning rule, or None when no rule keyword matched.
    rule: DocumentRule | None
    # Confidence derived from ``score`` (0.0 when nothing matched).
    confidence: float
    # Matched keywords of the winning rule (at most four are kept).
    evidence: tuple[str, ...]
    # Raw score of the winning rule.
    score: float
class LlmDocumentClassification(BaseModel):
    # Structured classification expected from the model-review pass.
    # (Documented with comments rather than a docstring: pydantic would copy a
    # class docstring into the generated JSON schema.)

    # Document type identifier; unknown values are collapsed to "other" by
    # DocumentIntelligenceService._parse_llm_payload.
    document_type: str = "other"
    # Scene code / label and expense type; "other" / "其他票据" mean
    # "not specified, fall back to the rule defaults".
    scene_code: str = "other"
    scene_label: str = "其他票据"
    expense_type: str = "other"
    # Model confidence, validated to lie in [0, 1].
    confidence: float = Field(default=0.0, ge=0.0, le=1.0)
    # Free-text snippets the model cites as evidence.
    evidence: list[str] = Field(default_factory=list)
    # Key fields the model read from the document.
    fields: list[DocumentField] = Field(default_factory=list)
# Classification used when no rule keyword matches the document text.
DEFAULT_RULE = DocumentRule(
    document_type="other",
    document_type_label="其他单据",
    scene_code="other",
    scene_label="其他票据",
    expense_type="other",
    keywords=(),
    score_bias=0.0,
)
# Keyword rules scored by _match_document_rule: every matched keyword adds to
# the rule's score and ``score_bias`` shifts the whole rule up or down. The
# generic invoice/receipt rules carry negative biases — presumably so that a
# more specific rule with a comparable number of hits wins.
# NOTE: order matters — on equal scores the rule listed first wins.
DOCUMENT_RULES: tuple[DocumentRule, ...] = (
    DocumentRule(
        document_type="flight_itinerary",
        document_type_label="机票/航班行程单",
        scene_code="travel",
        scene_label="差旅票据",
        expense_type="travel",
        keywords=("电子行程单", "航班号", "航班", "机票", "登机", "航空", "客票"),
        score_bias=0.34,
    ),
    DocumentRule(
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        expense_type="travel",
        keywords=("铁路电子客票", "电子客票", "高铁", "火车", "动车", "铁路", "车次", "检票", "二等座", "一等座", "票价"),
        score_bias=0.32,
    ),
    DocumentRule(
        document_type="hotel_invoice",
        document_type_label="酒店住宿票据",
        scene_code="hotel",
        scene_label="住宿票据",
        expense_type="hotel",
        keywords=("住宿", "房费", "客房", "入住", "离店", "酒店", "宾馆", "间夜"),
        score_bias=0.16,
    ),
    DocumentRule(
        document_type="taxi_receipt",
        document_type_label="出租车/网约车票据",
        scene_code="transport",
        scene_label="交通票据",
        expense_type="transport",
        keywords=("滴滴出行", "滴滴", "网约车", "出租车", "打车", "乘车", "用车", "叫车", "车费", "车资", "的士", "快车", "专车", "订单号", "上车", "下车", "起点", "终点", "里程", "司机"),
        score_bias=0.38,
    ),
    DocumentRule(
        document_type="parking_toll_receipt",
        document_type_label="停车/通行费票据",
        scene_code="transport",
        scene_label="交通票据",
        expense_type="transport",
        keywords=("停车费", "通行费", "过路费", "收费站", "停车场", "停车"),
        score_bias=0.28,
    ),
    DocumentRule(
        document_type="meal_receipt",
        document_type_label="餐饮票据",
        scene_code="meal",
        scene_label="餐饮票据",
        expense_type="meal",
        keywords=("餐饮", "餐费", "用餐", "饭店", "酒楼", "餐厅", "食品", "外卖", "咖啡"),
        score_bias=0.14,
    ),
    DocumentRule(
        document_type="office_invoice",
        document_type_label="办公用品票据",
        scene_code="office",
        scene_label="办公用品票据",
        expense_type="office",
        keywords=("办公用品", "文具", "耗材", "打印纸", "墨盒", "硒鼓", "键盘", "鼠标"),
        score_bias=0.14,
    ),
    DocumentRule(
        document_type="meeting_invoice",
        document_type_label="会议/会务票据",
        scene_code="meeting",
        scene_label="会务票据",
        expense_type="meeting",
        keywords=("会议", "会务", "会展", "论坛", "会议室", "会场"),
        score_bias=0.12,
    ),
    DocumentRule(
        document_type="training_invoice",
        document_type_label="培训票据",
        scene_code="training",
        scene_label="培训票据",
        expense_type="training",
        keywords=("培训", "课程", "讲师", "教材", "学费", "认证"),
        score_bias=0.12,
    ),
    # Generic VAT invoice: matches almost any formal invoice, hence the
    # negative bias.
    DocumentRule(
        document_type="vat_invoice",
        document_type_label="增值税发票",
        scene_code="other",
        scene_label="通用发票",
        expense_type="other",
        keywords=("发票代码", "发票号码", "价税合计", "增值税", "电子发票"),
        score_bias=-0.08,
    ),
    # Catch-all receipt / voucher rule with the lowest bias.
    DocumentRule(
        document_type="receipt",
        document_type_label="一般收据/凭证",
        scene_code="other",
        scene_label="其他票据",
        expense_type="other",
        keywords=("收据", "凭证", "票据"),
        score_bias=-0.18,
    ),
)
# Lookup from document_type to its rule.
DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES}
# Document types accepted from the model; anything else is collapsed to "other".
SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",)
# Amount extractors; every match of every pattern is considered by
# _extract_amount, which keeps the largest positive value. Group 1 is the
# number in each pattern:
#   1. keyword-prefixed amounts ("价税合计：¥128.50", "车费 23.5");
#   2. yen-prefixed amounts (half-width U+00A5 or full-width U+FFE5 yen sign);
#   3. amounts followed by 元.
# Separators after a keyword accept both the ASCII ":" and the full-width "："
# that Chinese invoices and OCR output commonly use.
AMOUNT_PATTERNS = (
    re.compile(
        r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
        r"[:：\s\u00a5\uffe5人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
    ),
    re.compile(r"[\u00a5\uffe5]\s*([0-9]+(?:[.,][0-9]{1,2})?)"),
    re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
# Dates such as 2024-01-05, 2024/1/5, 2024.01.05 or 2024年1月5日 (years 1900-2099).
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
# Invoice / ticket / order numbers and invoice codes following their label.
INVOICE_NUMBER_PATTERN = re.compile(r"(?:发票号码|票号|单号|订单号)[:：\s]*([A-Za-z0-9-]{6,24})")
INVOICE_CODE_PATTERN = re.compile(r"(?:发票代码)[:：\s]*([A-Za-z0-9-]{6,24})")
# Train number (车次) or flight number (航班/航班号).
TRIP_NO_PATTERN = re.compile(r"(?:车次|航班(?:号)?)[:：\s]*([A-Za-z0-9]{2,12})")
# "<origin> 至|→|->|- <destination>" with Chinese place names.
ROUTE_PATTERN = re.compile(r"([\u4e00-\u9fa5]{2,12})\s*(?:至|→|->|-)\s*([\u4e00-\u9fa5]{2,12})")
# Merchant name: first after an explicit seller/merchant label, otherwise any
# name ending in a typical merchant suffix (酒店, 餐厅, 航空, ...).
MERCHANT_PATTERNS = (
    re.compile(r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)[:：\s]*([A-Za-z0-9\u4e00-\u9fa5()·&\\-]{2,40})"),
    re.compile(r"([A-Za-z0-9\u4e00-\u9fa5()·&\\-]{2,40}(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"),
)
class DocumentIntelligenceService:
    """Classify an expense document and extract its key fields.

    The keyword rules in ``DOCUMENT_RULES`` always run. ``_classify_with_model``
    is a hook for an optional model review; when it returns a result, that
    result is reconciled with the rule result by ``_merge_rule_and_model``.
    """

    # Below this model confidence the rule classification is always kept, and
    # model fields are only merged when a preview image was available.
    _MODEL_MIN_CONFIDENCE = 0.55

    def __init__(self, db: Session | None = None) -> None:
        # Optional DB session. No method in this class reads it — presumably
        # reserved for a model-backed ``_classify_with_model`` (TODO confirm).
        self.db = db

    def build_document_insight(
        self,
        *,
        filename: str = "",
        summary: str = "",
        text: str = "",
        preview_data_url: str = "",
    ) -> DocumentInsight:
        """Build a :class:`DocumentInsight` from the available document text.

        Args:
            filename: Original file name (also used as classification text).
            summary: Short document summary (also used as classification text).
            text: Extracted body text of the document.
            preview_data_url: Image preview; forwarded to the model hook and,
                when non-blank, relaxes the model field/override thresholds.

        Returns:
            The rule-based insight, or the rule/model merge when
            ``_classify_with_model`` yields a result.
        """
        # Normalize every input once so the rule pass, the model hook and the
        # ``has_preview`` flag all see exactly the same values (a whitespace-only
        # preview is treated as "no preview" everywhere).
        filename = str(filename or "").strip()
        summary = str(summary or "").strip()
        text = str(text or "").strip()
        preview_data_url = str(preview_data_url or "").strip()

        raw_text = " ".join([filename, summary, text]).strip()
        # Keyword matching runs on a whitespace-free, lower-cased copy.
        compact = re.sub(r"\s+", "", raw_text).lower()
        rule_match = _match_document_rule(compact)
        base_rule = rule_match.rule or DEFAULT_RULE
        fields = tuple(_extract_document_fields(raw_text))
        rule_insight = DocumentInsight(
            document_type=base_rule.document_type,
            document_type_label=base_rule.document_type_label,
            scene_code=base_rule.scene_code,
            scene_label=base_rule.scene_label,
            expense_type=base_rule.expense_type,
            fields=fields,
            classification_source="rule",
            classification_confidence=rule_match.confidence,
            evidence=rule_match.evidence,
        )
        llm_result = self._classify_with_model(
            filename=filename,
            summary=summary,
            text=text,
            preview_data_url=preview_data_url,
            rule_insight=rule_insight,
            fields=fields,
        )
        if llm_result is None:
            return rule_insight
        return self._merge_rule_and_model(
            rule_insight=rule_insight,
            llm_result=llm_result,
            fields=fields,
            has_preview=bool(preview_data_url),
        )

    def _classify_with_model(
        self,
        *,
        filename: str,
        summary: str,
        text: str,
        preview_data_url: str,
        rule_insight: DocumentInsight,
        fields: tuple[DocumentField, ...],
    ) -> tuple[str, LlmDocumentClassification] | None:
        """Hook for a model-based classification pass.

        Expected to return ``(source_label, classification)`` — the label ends
        up in ``DocumentInsight.classification_source`` — or ``None`` to keep
        the rule result. This implementation always returns ``None``, so the
        rule-based insight is used unchanged.
        """
        return None

    @staticmethod
    def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None:
        """Parse and sanitize a raw model response.

        Unknown ``document_type`` values collapse to ``"other"``; empty scene /
        expense values fall back to that type's rule defaults; evidence is
        trimmed to at most four non-blank entries; fields are normalized to
        the canonical keys.

        Returns:
            The sanitized classification, or ``None`` when no JSON object is
            found or it fails validation.
        """
        payload_json = _extract_json_payload(response_text)
        if payload_json is None:
            return None
        try:
            parsed = LlmDocumentClassification.model_validate(payload_json)
        except ValidationError:
            return None
        normalized_type = str(parsed.document_type or "other").strip().lower() or "other"
        if normalized_type not in SUPPORTED_DOCUMENT_TYPES:
            normalized_type = "other"
        base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE)
        evidence = [snippet for snippet in (str(item or "").strip() for item in parsed.evidence) if snippet][:4]
        return LlmDocumentClassification(
            document_type=normalized_type,
            scene_code=str(parsed.scene_code or base_rule.scene_code).strip() or base_rule.scene_code,
            scene_label=str(parsed.scene_label or base_rule.scene_label).strip() or base_rule.scene_label,
            expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type,
            confidence=float(parsed.confidence),
            evidence=evidence,
            fields=_normalize_llm_document_fields(parsed.fields),
        )

    @staticmethod
    def _keep_rule_classification(
        rule_insight: DocumentInsight,
        merged_fields: tuple[DocumentField, ...],
        warnings: list[str],
    ) -> DocumentInsight:
        """Keep the rule classification, only swapping in merged fields/warnings."""
        if merged_fields == rule_insight.fields:
            # Warnings are only added when fields changed, so nothing to update.
            return rule_insight
        return DocumentInsight(
            document_type=rule_insight.document_type,
            document_type_label=rule_insight.document_type_label,
            scene_code=rule_insight.scene_code,
            scene_label=rule_insight.scene_label,
            expense_type=rule_insight.expense_type,
            fields=merged_fields,
            classification_source=rule_insight.classification_source,
            classification_confidence=rule_insight.classification_confidence,
            evidence=rule_insight.evidence,
            warnings=tuple(warnings),
        )

    @staticmethod
    def _merge_rule_and_model(
        *,
        rule_insight: DocumentInsight,
        llm_result: tuple[str, LlmDocumentClassification],
        fields: tuple[DocumentField, ...],
        has_preview: bool,
    ) -> DocumentInsight:
        """Reconcile the rule-based insight with a model classification.

        Fields: model fields override/extend the rule fields when the model saw
        a preview or is at least ``_MODEL_MIN_CONFIDENCE`` confident.

        Type: below ``_MODEL_MIN_CONFIDENCE`` the rule type is kept. Otherwise
        the model type wins when it agrees with the rules, when the rules found
        nothing (``"other"``), or — for a conflicting concrete type — when the
        model clears the override threshold (0.60 with a preview, else
        ``max(0.76, rule confidence + 0.12)``). A model answer of ``"other"``
        never overrides a concrete rule type.

        Confidence/evidence: when both agree on the type, the higher confidence
        is reported and rule evidence backs up missing model evidence. When the
        model overrides a *different* rule type, only the model's own
        confidence and evidence are reported — the rule's numbers supported
        another type.

        ``fields`` is accepted for interface compatibility and is not used; the
        rule fields are read from ``rule_insight``.
        """
        source, parsed = llm_result
        cls = DocumentIntelligenceService
        warnings = list(rule_insight.warnings)

        merged_fields = rule_insight.fields
        if parsed.fields and (has_preview or parsed.confidence >= cls._MODEL_MIN_CONFIDENCE):
            merged_fields = _merge_document_fields(rule_insight.fields, tuple(parsed.fields))
            if merged_fields != rule_insight.fields:
                warnings.append("票据关键信息已结合大模型复核结果修正,建议人工再核对原图。")

        if parsed.confidence < cls._MODEL_MIN_CONFIDENCE:
            return cls._keep_rule_classification(rule_insight, merged_fields, warnings)

        same_type = parsed.document_type == rule_insight.document_type
        if same_type:
            should_override = True
        elif parsed.document_type == "other":
            # A vague model answer never replaces a concrete rule type.
            should_override = False
        elif rule_insight.document_type == "other":
            should_override = True
        else:
            # Conflicting concrete types: demand more certainty without a preview.
            threshold = 0.60 if has_preview else max(0.76, rule_insight.classification_confidence + 0.12)
            should_override = parsed.confidence >= threshold
        if not should_override:
            return cls._keep_rule_classification(rule_insight, merged_fields, warnings)

        rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
        if same_type:
            confidence = max(parsed.confidence, rule_insight.classification_confidence)
            evidence = tuple(parsed.evidence or rule_insight.evidence)
        else:
            warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
            confidence = parsed.confidence
            evidence = tuple(parsed.evidence)
        return DocumentInsight(
            document_type=rule.document_type,
            document_type_label=rule.document_type_label,
            # "other" / "其他票据" from the model mean "unspecified": use the rule defaults.
            scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
            scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
            expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
            fields=merged_fields,
            classification_source=source,
            classification_confidence=confidence,
            evidence=evidence,
            warnings=tuple(warnings),
        )
def build_document_insight(
    *,
    filename: str = "",
    summary: str = "",
    text: str = "",
    preview_data_url: str = "",
) -> DocumentInsight:
    """Classify a document using a service instance created without a DB session.

    Shorthand for ``DocumentIntelligenceService().build_document_insight(...)``
    with the same keyword arguments.
    """
    service = DocumentIntelligenceService()
    insight = service.build_document_insight(
        filename=filename,
        summary=summary,
        text=text,
        preview_data_url=preview_data_url,
    )
    return insight
def _match_document_rule(compact_text: str) -> RuleMatch:
    """Pick the best-scoring rule in ``DOCUMENT_RULES`` for *compact_text*.

    A rule is scored only when at least one of its keywords occurs in the
    text: ``score_bias + 0.92 per hit + 0.08 * min(len(keyword), 6) per hit``.
    The highest strictly positive score wins, with ties going to the rule
    listed first. Confidence is ``0.30 + 0.12 * min(score, 4.8)``, capped at
    0.94 and rounded to two decimals; evidence is the winner's first four hits.

    Args:
        compact_text: Document text, expected whitespace-free and lower-cased
            (as ``DocumentIntelligenceService.build_document_insight`` prepares it).

    Returns:
        The winning match, or an empty ``RuleMatch`` (``rule=None``) when no
        rule scores above zero.
    """
    scored: list[tuple[float, DocumentRule, tuple[str, ...]]] = []
    for candidate in DOCUMENT_RULES:
        hits = tuple(word for word in candidate.keywords if word.lower() in compact_text)
        if not hits:
            continue
        total = float(candidate.score_bias) + len(hits) * 0.92 + sum(min(len(word), 6) * 0.08 for word in hits)
        if total > 0:
            scored.append((total, candidate, hits))

    if not scored:
        return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0)

    # max() returns the first maximal entry, preserving "first rule wins ties".
    top_score, top_rule, top_hits = max(scored, key=lambda entry: entry[0])
    raw_confidence = min(0.94, 0.30 + min(top_score, 4.8) * 0.12)
    return RuleMatch(
        rule=top_rule,
        confidence=round(raw_confidence, 2),
        evidence=top_hits[:4],
        score=top_score,
    )
def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
if not response_text:
return None
cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE).strip()
if not cleaned:
return None
fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL)
candidates = [fenced_match.group(1)] if fenced_match else []
candidates.append(cleaned)
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
candidates.append(cleaned[start : end + 1])
for candidate in candidates:
try:
parsed = json.loads(candidate)
except json.JSONDecodeError:
continue
if isinstance(parsed, dict):
return parsed
return None
def _llm_field_part(item: Any, name: str) -> str:
    """Read attribute/entry ``name`` of a raw field item as a stripped string ("" for None)."""
    if isinstance(item, DocumentField):
        raw = getattr(item, name, None)
    elif hasattr(item, "get"):
        raw = item.get(name)
    else:
        raw = getattr(item, name, None)
    return str(raw or "").strip()


def _normalize_llm_document_fields(raw_fields: list[DocumentField] | list[dict[str, Any]]) -> list[DocumentField]:
    """Convert model-reported fields into canonical, de-duplicated DocumentFields.

    Each item may be a ``DocumentField``, a mapping with ``key``/``label``/
    ``value`` entries, or any object exposing those attributes. Missing or
    ``None`` parts are treated as empty. An item is dropped when its key/label
    cannot be mapped to a canonical key, when its value normalizes to empty,
    or when its canonical key was already produced (first occurrence wins).
    Labels are replaced by the canonical display label.

    Args:
        raw_fields: Items reported by the model; ``None`` is treated as empty.

    Returns:
        The normalized fields in input order.
    """
    normalized: list[DocumentField] = []
    seen_keys: set[str] = set()
    for item in raw_fields or ():
        raw_key = _llm_field_part(item, "key")
        raw_label = _llm_field_part(item, "label")
        raw_value = _llm_field_part(item, "value")
        key = _normalize_llm_document_field_key(raw_key, raw_label)
        if not key or key in seen_keys:
            continue
        value = _normalize_llm_document_field_value(key, raw_value)
        if not value:
            continue
        seen_keys.add(key)
        normalized.append(
            DocumentField(
                key=key,
                label=_llm_document_field_label(key),
                value=value,
            )
        )
    return normalized
def _normalize_llm_document_field_key(key: str, label: str) -> str:
compact_key = str(key or "").strip().lower()
compact_label = str(label or "").replace(" ", "").lower()
if compact_key in {"amount", "total_amount", "payment_amount", "paid_amount"} or any(
token in compact_label for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
):
return "amount"
if compact_key in {"date", "time", "issued_at", "invoice_date"} or any(
token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")
):
return "date"
if compact_key in {"merchant_name", "merchant", "seller_name", "vendor_name"} or any(
token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方")
):
return "merchant_name"
if compact_key in {"invoice_number", "ticket_number", "order_no", "order_number"} or any(
token in compact_label for token in ("票据号码", "发票号码", "票号", "单号", "订单号")
):
return "invoice_number"
if compact_key in {"invoice_code"} or "发票代码" in compact_label:
return "invoice_code"
if compact_key in {"trip_no", "flight_no", "train_no"} or any(
token in compact_label for token in ("车次", "航班")
):
return "trip_no"
if compact_key in {"route", "trip_route"} or any(token in compact_label for token in ("行程", "路线")):
return "route"
return ""
def _normalize_llm_document_field_value(key: str, value: str) -> str:
    """Normalize a model-reported value for the canonical field ``key``.

    * ``amount`` — reuse the OCR amount extractor; if it finds nothing (e.g. a
      bare number without keyword, yen sign or 元), parse the value directly as
      a positive decimal rounded to cents and render it without trailing zeros.
    * ``date`` — ``YYYY-MM-DD`` when recognizable, else the trimmed value.
    * ``route`` — ``origin-destination`` when recognizable, else the trimmed
      value with 至 / → / -> separators mapped to ``-``.
    * anything else — the trimmed value.

    Returns:
        The normalized value, or ``""`` when the input is empty or (for
        amounts) not a finite positive number.
    """
    raw_value = str(value or "").strip()
    if not raw_value:
        return ""
    if key == "amount":
        amount = _extract_amount(raw_value)
        if amount:
            return amount
        cleaned = raw_value
        # RMB marker, full-width yen, half-width yen, yuan.
        for marker in ("人民币", "\uffe5", "\u00a5", "元"):
            cleaned = cleaned.replace(marker, "")
        cleaned = cleaned.strip()
        if re.fullmatch(r"[0-9]{1,3}(?:,[0-9]{3})+(?:\.[0-9]+)?", cleaned):
            cleaned = cleaned.replace(",", "")  # "1,280.00": thousands separators
        else:
            cleaned = cleaned.replace(",", ".")  # "12,5": comma used as decimal point
        try:
            candidate = Decimal(cleaned)
            # Reject NaN/Infinity before comparing (ordering a NaN raises), and
            # quantize inside the try: values beyond the default 28-digit
            # precision raise InvalidOperation.
            if not candidate.is_finite() or candidate <= 0:
                return ""
            normalized = candidate.quantize(Decimal("0.01"))
        except InvalidOperation:
            return ""
        if normalized <= 0:  # e.g. "0.001" rounds to 0.00
            return ""
        return format(normalized, "f").rstrip("0").rstrip(".")
    if key == "date":
        return _extract_date(raw_value) or _clean_field_value(raw_value)
    if key == "route":
        # Map the separators ROUTE_PATTERN accepts onto "-" for the fallback.
        return _extract_route(raw_value) or _clean_field_value(
            raw_value.replace("至", "-").replace("→", "-").replace("->", "-")
        )
    return _clean_field_value(raw_value)
def _llm_document_field_label(key: str) -> str:
return {
"amount": "金额",
"date": "日期",
"merchant_name": "商户",
"invoice_number": "票据号码",
"invoice_code": "发票代码",
"trip_no": "车次/航班",
"route": "行程",
}.get(key, key)
def _merge_document_fields(
base_fields: tuple[DocumentField, ...],
override_fields: tuple[DocumentField, ...],
) -> tuple[DocumentField, ...]:
merged = {field.key: field for field in base_fields if field.key and field.value}
order = [field.key for field in base_fields if field.key and field.value]
for field in override_fields:
if not field.key or not field.value:
continue
merged[field.key] = field
if field.key not in order:
order.append(field.key)
return tuple(merged[key] for key in order if key in merged)
def _extract_document_fields(text: str) -> list[DocumentField]:
    """Run every rule-based field extractor over *text*.

    Returns one ``DocumentField`` per extractor that produced a non-empty
    value, in the fixed order: amount, date, merchant, invoice number,
    invoice code, trip number, route.
    """
    extractors = (
        ("amount", "金额", _extract_amount),
        ("date", "日期", _extract_date),
        ("merchant_name", "商户", _extract_merchant),
        ("invoice_number", "票据号码", lambda source: _extract_pattern(INVOICE_NUMBER_PATTERN, source)),
        ("invoice_code", "发票代码", lambda source: _extract_pattern(INVOICE_CODE_PATTERN, source)),
        ("trip_no", "车次/航班", lambda source: _extract_pattern(TRIP_NO_PATTERN, source)),
        ("route", "行程", _extract_route),
    )
    results: list[DocumentField] = []
    for field_key, field_label, extract in extractors:
        extracted = extract(text)
        if extracted:
            results.append(DocumentField(key=field_key, label=field_label, value=extracted))
    return results
def _extract_amount(text: str) -> str:
    """Return the largest positive amount found in *text*, formatted compactly.

    Every match of every pattern in ``AMOUNT_PATTERNS`` is considered and the
    maximum wins (presumably so a grand total outranks line items). A comma
    inside a match is read as a decimal point. The value is rounded to cents
    and rendered without trailing zeros ("128.50" -> "128.5", "100.00" -> "100").

    Returns:
        The formatted amount, or ``""`` when no usable positive amount is found.
    """
    best_value: Decimal | None = None
    for pattern in AMOUNT_PATTERNS:
        for match in pattern.finditer(text):
            raw_value = str(match.group(1) or "").replace(",", ".").strip()
            try:
                # Quantize inside the try: a long digit run (e.g. an ID after a
                # keyword) exceeds the default 28-digit Decimal precision and
                # would otherwise raise InvalidOperation out of this function.
                candidate = Decimal(raw_value).quantize(Decimal("0.01"))
            except InvalidOperation:
                continue
            if candidate <= Decimal("0.00"):
                continue
            if best_value is None or candidate > best_value:
                best_value = candidate
    if best_value is None:
        return ""
    return format(best_value, "f").rstrip("0").rstrip(".")
def _extract_date(text: str) -> str:
    """Return the first date in *text* normalized to ``YYYY-MM-DD``.

    Handles the separators ``DATE_PATTERN`` accepts: ``-``, ``/``, ``.`` and
    the Chinese 年/月/日 form ("2024年1月5日" -> "2024-01-05").

    Returns:
        The normalized date; the raw match if it cannot be split into exactly
        three parts; ``""`` when no date is found.
    """
    match = DATE_PATTERN.search(text)
    if not match:
        return ""
    raw_value = str(match.group(1) or "").strip()
    # 年/月 become "-" and the optional trailing 日 is dropped.
    normalized = raw_value.replace("年", "-").replace("月", "-").replace("日", "")
    normalized = normalized.replace("/", "-").replace(".", "-")
    parts = [part for part in normalized.split("-") if part]
    if len(parts) != 3:
        return raw_value
    year, month, day = parts
    return f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}"
def _extract_merchant(text: str) -> str:
    """Return the first non-empty merchant name matched by ``MERCHANT_PATTERNS``.

    Patterns are tried in order and the search stops at the first one whose
    cleaned capture is non-empty; ``""`` if none yields a name.
    """
    cleaned_hits = (
        _clean_field_value(found.group(1))
        for found in (pattern.search(text) for pattern in MERCHANT_PATTERNS)
        if found
    )
    return next((name for name in cleaned_hits if name), "")
def _extract_route(text: str) -> str:
    """Return the first ``origin-destination`` route found in *text*.

    Returns ``""`` when no route matches, either end is empty after cleaning,
    or origin and destination are identical.
    """
    found = ROUTE_PATTERN.search(text)
    if found is None:
        return ""
    origin, destination = (_clean_field_value(part) for part in found.group(1, 2))
    if origin and destination and origin != destination:
        return f"{origin}-{destination}"
    return ""
def _extract_pattern(pattern: re.Pattern[str], text: str) -> str:
    """Return the cleaned group 1 of *pattern*'s first match in *text*, or ``""``."""
    if (found := pattern.search(text)) is None:
        return ""
    return _clean_field_value(found.group(1))
def _clean_field_value(value: str) -> str:
return str(value or "").strip().strip(":,。.;")