Files
X-Financial/server/src/app/services/document_intelligence.py

733 lines
28 KiB
Python
Raw Normal View History

from __future__ import annotations
import json
import re
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from app.services.runtime_chat import RuntimeChatService
@dataclass(frozen=True, slots=True)
class DocumentField:
key: str
label: str
value: str
@dataclass(frozen=True, slots=True)
class DocumentInsight:
document_type: str
document_type_label: str
scene_code: str
scene_label: str
expense_type: str
fields: tuple[DocumentField, ...] = ()
classification_source: str = "rule"
classification_confidence: float = 0.0
evidence: tuple[str, ...] = ()
warnings: tuple[str, ...] = ()
@dataclass(frozen=True, slots=True)
class DocumentRule:
document_type: str
document_type_label: str
scene_code: str
scene_label: str
expense_type: str
keywords: tuple[str, ...]
score_bias: float = 0.0
@dataclass(frozen=True, slots=True)
class RuleMatch:
rule: DocumentRule | None
confidence: float
evidence: tuple[str, ...]
score: float
class LlmDocumentClassification(BaseModel):
    """Pydantic schema for the JSON object the review model is asked to return.

    Every key is optional; absent keys default to the "other" classification.
    ``confidence`` is constrained to the closed interval [0, 1].
    """

    document_type: str = "other"
    scene_code: str = "other"
    scene_label: str = "其他票据"
    expense_type: str = "other"
    confidence: float = Field(0.0, ge=0.0, le=1.0)
    evidence: list[str] = Field(default_factory=list)
    fields: list[DocumentField] = Field(default_factory=list)
# Fallback classification: used when no keyword rule matches, and as the
# lookup default when a document type is not present in DOCUMENT_TYPE_RULE_MAP.
# It has no keywords, so it can never be selected by keyword matching itself.
DEFAULT_RULE = DocumentRule(
    document_type="other",
    document_type_label="其他单据",
    scene_code="other",
    scene_label="其他票据",
    expense_type="other",
    keywords=(),
    score_bias=0.0,
)
# Keyword rules scored by _match_document_rule. Each keyword is tested as a
# substring of the whitespace-free, lower-cased document text. Every matched
# keyword contributes to the score (overlapping keywords such as "滴滴" and
# "滴滴出行" are counted separately), and score_bias is added on top.
# Order matters: on equal scores the earlier rule is kept (strict ">" comparison).
DOCUMENT_RULES: tuple[DocumentRule, ...] = (
    DocumentRule(
        document_type="flight_itinerary",
        document_type_label="机票/航班行程单",
        scene_code="travel",
        scene_label="差旅票据",
        expense_type="travel",
        keywords=("电子行程单", "航班号", "航班", "机票", "登机", "航空", "客票"),
        score_bias=0.34,
    ),
    DocumentRule(
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        expense_type="travel",
        keywords=("高铁", "火车", "动车", "铁路", "车次", "检票", "二等座", "一等座"),
        score_bias=0.32,
    ),
    DocumentRule(
        document_type="hotel_invoice",
        document_type_label="酒店住宿票据",
        scene_code="hotel",
        scene_label="住宿票据",
        expense_type="hotel",
        keywords=("住宿", "房费", "客房", "入住", "离店", "酒店", "宾馆", "间夜"),
        score_bias=0.16,
    ),
    # Highest bias of all rules, so taxi/ride-hailing keywords outweigh
    # incidental matches from other rules for the same number of keyword hits.
    DocumentRule(
        document_type="taxi_receipt",
        document_type_label="出租车/网约车票据",
        scene_code="transport",
        scene_label="交通票据",
        expense_type="transport",
        keywords=("滴滴出行", "滴滴", "网约车", "出租车", "打车", "快车", "专车", "订单号", "上车", "下车", "起点", "终点", "里程", "司机"),
        score_bias=0.38,
    ),
    DocumentRule(
        document_type="parking_toll_receipt",
        document_type_label="停车/通行费票据",
        scene_code="transport",
        scene_label="交通票据",
        expense_type="transport",
        keywords=("停车费", "通行费", "过路费", "收费站", "停车场", "停车"),
        score_bias=0.28,
    ),
    DocumentRule(
        document_type="meal_receipt",
        document_type_label="餐饮票据",
        scene_code="meal",
        scene_label="餐饮票据",
        expense_type="meal",
        keywords=("餐饮", "餐费", "用餐", "饭店", "酒楼", "餐厅", "食品", "外卖", "咖啡"),
        score_bias=0.14,
    ),
    DocumentRule(
        document_type="office_invoice",
        document_type_label="办公用品票据",
        scene_code="office",
        scene_label="办公用品票据",
        expense_type="office",
        keywords=("办公用品", "文具", "耗材", "打印纸", "墨盒", "硒鼓", "键盘", "鼠标"),
        score_bias=0.14,
    ),
    DocumentRule(
        document_type="meeting_invoice",
        document_type_label="会议/会务票据",
        scene_code="meeting",
        scene_label="会务票据",
        expense_type="meeting",
        keywords=("会议", "会务", "会展", "论坛", "会议室", "会场"),
        score_bias=0.12,
    ),
    DocumentRule(
        document_type="training_invoice",
        document_type_label="培训票据",
        scene_code="training",
        scene_label="培训票据",
        expense_type="training",
        keywords=("培训", "课程", "讲师", "教材", "学费", "认证"),
        score_bias=0.12,
    ),
    # Generic invoice markers: the negative bias lowers their score relative to
    # the more specific rules above when both match.
    DocumentRule(
        document_type="vat_invoice",
        document_type_label="增值税发票",
        scene_code="other",
        scene_label="通用发票",
        expense_type="other",
        keywords=("发票代码", "发票号码", "价税合计", "增值税", "电子发票"),
        score_bias=-0.08,
    ),
    # Most generic markers, with the lowest bias of all rules.
    DocumentRule(
        document_type="receipt",
        document_type_label="一般收据/凭证",
        scene_code="other",
        scene_label="其他票据",
        expense_type="other",
        keywords=("收据", "凭证", "票据"),
        score_bias=-0.18,
    ),
)
# document_type -> rule lookup, in DOCUMENT_RULES order.
DOCUMENT_TYPE_RULE_MAP: dict[str, DocumentRule] = {entry.document_type: entry for entry in DOCUMENT_RULES}
# Every type the model may return: all rule types followed by the "other" fallback.
SUPPORTED_DOCUMENT_TYPES: tuple[str, ...] = (*DOCUMENT_TYPE_RULE_MAP, "other")
# Extraction patterns for OCR text. Chinese invoices and receipts normally use
# the full-width colon "：" and the full-width yuan sign "￥", so every labelled
# pattern accepts them alongside the ASCII ":" / "¥".
AMOUNT_PATTERNS = (
    # Labelled amounts, e.g. "价税合计：￥128.50" or "金额 32.5".
    re.compile(
        r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
        r"[:：\s¥￥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
    ),
    # Currency-sign prefix, e.g. "¥ 32.5" or "￥32.5".
    re.compile(r"[¥￥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"),
    # Yuan suffix, e.g. "32.5元".
    re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
# Dates such as "2024-03-05", "2024/3/5", "2024.03.05" or "2024年3月5日" (years 1900-2099).
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
INVOICE_NUMBER_PATTERN = re.compile(r"(?:发票号码|票号|单号|订单号)[:：\s]*([A-Za-z0-9-]{6,24})")
INVOICE_CODE_PATTERN = re.compile(r"(?:发票代码)[:：\s]*([A-Za-z0-9-]{6,24})")
TRIP_NO_PATTERN = re.compile(r"(?:车次|航班(?:号)?)[:：\s]*([A-Za-z0-9]{2,12})")
# Two runs of 2-12 CJK characters joined by 至 / → / -> / -.
ROUTE_PATTERN = re.compile(r"([\u4e00-\u9fa5]{2,12})\s*(?:至|→|->|-)\s*([\u4e00-\u9fa5]{2,12})")
MERCHANT_PATTERNS = (
    # Labelled seller/merchant names; full-width parentheses are common in
    # company names such as "某某科技（北京）有限公司".
    re.compile(r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)[:：\s]*([A-Za-z0-9\u4e00-\u9fa5()（）·&\\-]{2,40})"),
    # Unlabelled names ending in a recognisable business suffix.
    re.compile(r"([A-Za-z0-9\u4e00-\u9fa5()（）·&\\-]{2,40}(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"),
)
class DocumentIntelligenceService:
    """Classify reimbursement documents and extract their key fields.

    Classification always starts with the keyword rules in ``DOCUMENT_RULES``.
    When the service is constructed with a database session, a
    ``RuntimeChatService`` is created. The rule result is then sent to a chat
    model for review: the vision slot is tried first when a preview image is
    supplied, then the text slots. ``_merge_rule_and_model`` merges the model
    verdict back conservatively.
    """

    # Closed vocabularies taken from the rule table. Model-supplied scene codes
    # and expense types outside these sets are replaced by the matched rule's
    # values instead of being passed downstream as free text.
    _KNOWN_SCENE_CODES = frozenset({rule.scene_code for rule in DOCUMENT_RULES} | {DEFAULT_RULE.scene_code})
    _KNOWN_EXPENSE_TYPES = frozenset({rule.expense_type for rule in DOCUMENT_RULES} | {DEFAULT_RULE.expense_type})

    def __init__(self, db: Session | None = None) -> None:
        """Create the service; without ``db`` only rule-based classification runs."""
        self.runtime_chat_service = RuntimeChatService(db) if db is not None else None

    def build_document_insight(
        self,
        *,
        filename: str = "",
        summary: str = "",
        text: str = "",
        preview_data_url: str = "",
    ) -> DocumentInsight:
        """Classify one document and extract its key fields.

        Args:
            filename: original file name (also searched for keywords).
            summary: optional short description of the document.
            text: OCR text of the document.
            preview_data_url: optional image URL / data URL for the vision model.

        Returns:
            The rule-based insight, or that insight merged with the model review
            when a chat service is configured and returned a usable answer.
        """
        filename = str(filename or "").strip()
        summary = str(summary or "").strip()
        text = str(text or "").strip()
        preview_data_url = str(preview_data_url or "").strip()
        raw_text = " ".join([filename, summary, text]).strip()
        compact = re.sub(r"\s+", "", raw_text).lower()
        rule_match = _match_document_rule(compact)
        base_rule = rule_match.rule or DEFAULT_RULE
        fields = tuple(_extract_document_fields(raw_text))
        rule_insight = DocumentInsight(
            document_type=base_rule.document_type,
            document_type_label=base_rule.document_type_label,
            scene_code=base_rule.scene_code,
            scene_label=base_rule.scene_label,
            expense_type=base_rule.expense_type,
            fields=fields,
            classification_source="rule",
            classification_confidence=rule_match.confidence,
            evidence=rule_match.evidence,
        )
        llm_result = self._classify_with_model(
            filename=filename,
            summary=summary,
            text=text,
            preview_data_url=preview_data_url,
            rule_insight=rule_insight,
            fields=fields,
        )
        if llm_result is None:
            return rule_insight
        return self._merge_rule_and_model(
            rule_insight=rule_insight,
            llm_result=llm_result,
            fields=fields,
            has_preview=bool(preview_data_url),
        )

    def _classify_with_model(
        self,
        *,
        filename: str,
        summary: str,
        text: str,
        preview_data_url: str,
        rule_insight: DocumentInsight,
        fields: tuple[DocumentField, ...],
    ) -> tuple[str, LlmDocumentClassification] | None:
        """Ask the runtime chat model to review the rule-based classification.

        Returns:
            ``("llm_vision", parsed)`` when the vision slot gave a usable answer,
            otherwise ``("llm_text", parsed)`` from the text slots. Returns
            ``None`` when there is no chat service, no OCR text or summary to
            review, or neither call produced a parseable answer.
        """
        if self.runtime_chat_service is None:
            return None
        trimmed_text = text.strip()
        if not trimmed_text and not summary.strip():
            return None
        facts = {
            "filename": filename,
            "summary": summary[:300],
            "ocr_text_excerpt": trimmed_text[:2000],
            "rule_candidate": {
                "document_type": rule_insight.document_type,
                "document_type_label": rule_insight.document_type_label,
                "scene_code": rule_insight.scene_code,
                "scene_label": rule_insight.scene_label,
                "expense_type": rule_insight.expense_type,
                "confidence": round(rule_insight.classification_confidence, 2),
                "evidence": list(rule_insight.evidence),
            },
            "extracted_fields": [
                {"key": field.key, "label": field.label, "value": field.value}
                for field in fields
            ],
            "allowed_document_types": list(SUPPORTED_DOCUMENT_TYPES),
        }
        system_prompt = (
            "你是企业报销票据识别复核器。"
            "你的任务不是 OCR而是在已有 OCR 文本和票据预览基础上判断票据类型,并尽量复核关键字段。"
            "只输出 JSON 对象,不要输出 Markdown、解释或代码块。"
            "document_type 只能是:"
            f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}"
            "如果证据不足,返回 other。"
            "严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。"
            "如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。"
            "例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。"
            "如果能从 OCR 或图片中明确确认字段,可在 fields 中返回。"
            "fields 只允许包含 key, label, valuekey 只能是 amount, date, merchant_name, invoice_number, "
            "invoice_code, trip_no, route。无法确认就不要返回该字段。"
            "输出字段document_type, scene_code, scene_label, expense_type, confidence, evidence, fields。"
        )
        user_prompt = (
            "请根据以下票据事实给出最终分类 JSON\n"
            f"{json.dumps(facts, ensure_ascii=False, indent=2)}\n\n"
            "示例输出:\n"
            "{\n"
            ' "document_type": "taxi_receipt",\n'
            ' "scene_code": "transport",\n'
            ' "scene_label": "交通票据",\n'
            ' "expense_type": "transport",\n'
            ' "confidence": 0.86,\n'
            ' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"],\n'
            ' "fields": [{"key": "amount", "label": "金额", "value": "32.5"}]\n'
            "}"
        )
        if preview_data_url:
            response_text = self.runtime_chat_service.complete(
                [
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": user_prompt},
                            {"type": "image_url", "image_url": {"url": preview_data_url}},
                        ],
                    },
                ],
                slot_priority=("vlm",),
                max_tokens=320,
                temperature=0.0,
            )
            parsed = self._parse_llm_payload(response_text)
            if parsed is not None:
                return "llm_vision", parsed
        # Text-only attempt: used when there is no preview or the vision answer was unusable.
        response_text = self.runtime_chat_service.complete(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            slot_priority=("main", "backup"),
            max_tokens=320,
            temperature=0.0,
        )
        parsed = self._parse_llm_payload(response_text)
        if parsed is not None:
            return "llm_text", parsed
        return None

    @staticmethod
    def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None:
        """Parse and sanitize a model reply into an ``LlmDocumentClassification``.

        Returns ``None`` when no JSON object is found, or when the core
        classification keys fail validation (e.g. ``confidence`` outside [0, 1]).

        ``fields`` and ``evidence`` are sanitized separately. A malformed entry
        there (a numeric ``value``, a bare-string ``evidence``, ``null``) therefore
        no longer discards an otherwise usable classification. Unknown document
        types become ``"other"``. Scene codes and expense types outside the rule
        vocabulary fall back to the matched rule's values.
        """
        payload_json = _extract_json_payload(response_text)
        if payload_json is None:
            return None
        payload = dict(payload_json)
        raw_fields = payload.pop("fields", None)
        raw_evidence = payload.pop("evidence", None)
        try:
            parsed = LlmDocumentClassification.model_validate(payload)
        except ValidationError:
            return None
        normalized_type = str(parsed.document_type or "other").strip().lower() or "other"
        if normalized_type not in SUPPORTED_DOCUMENT_TYPES:
            normalized_type = "other"
        base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE)

        scene_code = str(parsed.scene_code or "").strip().lower()
        if scene_code in DocumentIntelligenceService._KNOWN_SCENE_CODES:
            scene_label = str(parsed.scene_label or "").strip() or base_rule.scene_label
        else:
            # Unknown/empty scene: the model's label would describe a scene we do not know.
            scene_code, scene_label = base_rule.scene_code, base_rule.scene_label
        expense_type = str(parsed.expense_type or "").strip().lower()
        if expense_type not in DocumentIntelligenceService._KNOWN_EXPENSE_TYPES:
            expense_type = base_rule.expense_type

        if isinstance(raw_evidence, str):
            evidence_items: list[Any] = [raw_evidence]
        elif isinstance(raw_evidence, list):
            evidence_items = raw_evidence
        else:
            evidence_items = []
        evidence = [
            str(item or "").strip()
            for item in evidence_items
            if str(item or "").strip()
        ][:4]

        field_items = (
            [item for item in raw_fields if isinstance(item, dict)]
            if isinstance(raw_fields, list)
            else []
        )
        return LlmDocumentClassification(
            document_type=normalized_type,
            scene_code=scene_code,
            scene_label=scene_label,
            expense_type=expense_type,
            confidence=float(parsed.confidence),
            evidence=evidence,
            fields=_normalize_llm_document_fields(field_items),
        )

    @staticmethod
    def _model_may_override(
        *,
        rule_insight: DocumentInsight,
        parsed: LlmDocumentClassification,
        saw_preview: bool,
    ) -> bool:
        """Decide whether the model's document type should replace the rule's."""
        if parsed.document_type == rule_insight.document_type:
            return True
        if rule_insight.document_type == "other":
            # Types differ, so the model proposed a concrete type where rules found none.
            return True
        if parsed.document_type == "other":
            # Never downgrade a concrete rule match to "other".
            return False
        threshold = 0.60 if saw_preview else max(0.76, rule_insight.classification_confidence + 0.12)
        return parsed.confidence >= threshold

    @staticmethod
    def _keep_rule_classification(
        rule_insight: DocumentInsight,
        merged_fields: tuple[DocumentField, ...],
        warnings: list[str],
    ) -> DocumentInsight:
        """Return the rule classification, carrying over model-corrected fields if any."""
        if merged_fields == rule_insight.fields:
            return rule_insight
        return DocumentInsight(
            document_type=rule_insight.document_type,
            document_type_label=rule_insight.document_type_label,
            scene_code=rule_insight.scene_code,
            scene_label=rule_insight.scene_label,
            expense_type=rule_insight.expense_type,
            fields=merged_fields,
            classification_source=rule_insight.classification_source,
            classification_confidence=rule_insight.classification_confidence,
            evidence=rule_insight.evidence,
            warnings=tuple(warnings),
        )

    @staticmethod
    def _merge_rule_and_model(
        *,
        rule_insight: DocumentInsight,
        llm_result: tuple[str, LlmDocumentClassification],
        fields: tuple[DocumentField, ...],
        has_preview: bool,
    ) -> DocumentInsight:
        """Combine the rule-based insight with the model's review.

        Args:
            rule_insight: result of the keyword rules; its ``fields`` are the base
                that model corrections are merged onto.
            llm_result: ``(source, parsed)`` from ``_classify_with_model``, where
                ``source`` is ``"llm_vision"`` or ``"llm_text"``.
            fields: rule-extracted fields. Not read here (``rule_insight.fields``
                is merged instead); kept for call compatibility.
            has_preview: whether a preview image was supplied. The relaxed
                vision thresholds apply only when the answer actually came from
                the vision call. A text-only fallback after a failed vision call
                is held to the text-only bar.

        Field corrections are accepted when the model saw the image or is at
        least 0.55 confident. The document type is replaced only when the model
        is at least 0.55 confident and ``_model_may_override`` allows it.
        Otherwise the rule classification is kept, with any corrected fields.
        A warning is appended for each kind of correction.
        """
        source, parsed = llm_result
        saw_preview = has_preview and source == "llm_vision"
        warnings = list(rule_insight.warnings)
        merged_fields = rule_insight.fields
        if parsed.fields and (saw_preview or parsed.confidence >= 0.55):
            merged_fields = _merge_document_fields(rule_insight.fields, tuple(parsed.fields))
            if merged_fields != rule_insight.fields:
                warnings.append("票据关键信息已结合大模型复核结果修正,建议人工再核对原图。")
        if parsed.confidence < 0.55 or not DocumentIntelligenceService._model_may_override(
            rule_insight=rule_insight,
            parsed=parsed,
            saw_preview=saw_preview,
        ):
            return DocumentIntelligenceService._keep_rule_classification(
                rule_insight, merged_fields, warnings
            )
        rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
        same_type = parsed.document_type == rule_insight.document_type
        if not same_type:
            warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
        # The rule's confidence backs the rule's own type only; when the model
        # picks a different type, its own confidence is the honest figure.
        confidence = (
            max(parsed.confidence, rule_insight.classification_confidence)
            if same_type
            else parsed.confidence
        )
        return DocumentInsight(
            document_type=rule.document_type,
            document_type_label=rule.document_type_label,
            scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
            scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
            expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
            fields=merged_fields,
            classification_source=source,
            classification_confidence=confidence,
            evidence=tuple(parsed.evidence or rule_insight.evidence),
            warnings=tuple(warnings),
        )
def build_document_insight(
    *,
    filename: str = "",
    summary: str = "",
    text: str = "",
    preview_data_url: str = "",
) -> DocumentInsight:
    """Classify a document using keyword rules only.

    The service is created without a database session, so no chat model is
    consulted and the rule-based insight is returned as-is.
    """
    service = DocumentIntelligenceService()
    return service.build_document_insight(
        filename=filename,
        summary=summary,
        text=text,
        preview_data_url=preview_data_url,
    )
def _match_document_rule(compact_text: str) -> RuleMatch:
    """Pick the best-scoring keyword rule for ``compact_text``.

    A rule's score is its ``score_bias`` plus 0.92 per matched keyword, plus
    0.08 per character of each matched keyword (capped at 6 characters). The
    first rule with the strictly highest score wins. Confidence is derived from
    the score, capped at 4.8 for the score term and at 0.94 overall, and
    rounded to two decimals. At most four matched keywords are kept as
    evidence.

    Returns a ``RuleMatch`` with ``rule=None`` and zero confidence when no rule
    scores above zero.
    """
    winner = DEFAULT_RULE
    winner_hits: tuple[str, ...] = ()
    winner_score = 0.0
    for candidate in DOCUMENT_RULES:
        hits = tuple(word for word in candidate.keywords if word.lower() in compact_text)
        if not hits:
            continue
        score = float(candidate.score_bias) + len(hits) * 0.92 + sum(min(len(word), 6) * 0.08 for word in hits)
        if score > winner_score:
            winner, winner_hits, winner_score = candidate, hits, score
    if winner_score <= 0:
        return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0)
    confidence = min(0.94, 0.30 + min(winner_score, 4.8) * 0.12)
    return RuleMatch(
        rule=winner,
        confidence=round(confidence, 2),
        evidence=winner_hits[:4],
        score=winner_score,
    )
def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
if not response_text:
return None
cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE).strip()
if not cleaned:
return None
fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL)
candidates = [fenced_match.group(1)] if fenced_match else []
candidates.append(cleaned)
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
candidates.append(cleaned[start : end + 1])
for candidate in candidates:
try:
parsed = json.loads(candidate)
except json.JSONDecodeError:
continue
if isinstance(parsed, dict):
return parsed
return None
def _normalize_llm_document_fields(raw_fields: list[DocumentField] | list[dict[str, Any]]) -> list[DocumentField]:
    """Map model-supplied fields onto canonical keys, labels and values.

    Each entry may be a ``DocumentField`` or a dict with ``key``/``label``/``value``.
    Entries whose key cannot be mapped, or whose value normalizes to an empty
    string, are dropped. Only the first usable entry per canonical key is kept.
    Labels are replaced by the canonical label for the key.
    """
    result: list[DocumentField] = []
    used_keys: set[str] = set()
    for item in raw_fields:
        if isinstance(item, DocumentField):
            parts = (item.key, item.label, item.value)
        else:
            parts = (item.get("key") or "", item.get("label") or "", item.get("value") or "")
        raw_key, raw_label, raw_value = (str(part).strip() for part in parts)
        canonical = _normalize_llm_document_field_key(raw_key, raw_label)
        if not canonical or canonical in used_keys:
            continue
        cleaned_value = _normalize_llm_document_field_value(canonical, raw_value)
        if not cleaned_value:
            continue
        used_keys.add(canonical)
        result.append(
            DocumentField(
                key=canonical,
                label=_llm_document_field_label(canonical),
                value=cleaned_value,
            )
        )
    return result
def _normalize_llm_document_field_key(key: str, label: str) -> str:
compact_key = str(key or "").strip().lower()
compact_label = str(label or "").replace(" ", "").lower()
if compact_key in {"amount", "total_amount", "payment_amount", "paid_amount"} or any(
token in compact_label for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
):
return "amount"
if compact_key in {"date", "time", "issued_at", "invoice_date"} or any(
token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")
):
return "date"
if compact_key in {"merchant_name", "merchant", "seller_name", "vendor_name"} or any(
token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方")
):
return "merchant_name"
if compact_key in {"invoice_number", "ticket_number", "order_no", "order_number"} or any(
token in compact_label for token in ("票据号码", "发票号码", "票号", "单号", "订单号")
):
return "invoice_number"
if compact_key in {"invoice_code"} or "发票代码" in compact_label:
return "invoice_code"
if compact_key in {"trip_no", "flight_no", "train_no"} or any(
token in compact_label for token in ("车次", "航班")
):
return "trip_no"
if compact_key in {"route", "trip_route"} or any(token in compact_label for token in ("行程", "路线")):
return "route"
return ""
def _normalize_llm_document_field_value(key: str, value: str) -> str:
    """Normalize one model-supplied field value for the canonical ``key``.

    - ``amount``: try ``_extract_amount`` first. Otherwise parse the bare
      number, with ``￥``/``¥``/``元`` removed and ``,`` read as a decimal point
      (as the OCR extractor does). Unparseable, non-finite (NaN/Infinity),
      out-of-range or non-positive values, including values that round to
      0.00, yield ``""``.
    - ``date``: normalized via ``_extract_date``, else the cleaned raw text.
    - ``route``: normalized via ``_extract_route``; otherwise the separators
      ``->``, ``→`` and ``至`` are folded to ``-``.
    - any other key: the cleaned raw text.

    Returns ``""`` for empty input.
    """
    raw_value = str(value or "").strip()
    if not raw_value:
        return ""
    if key == "amount":
        amount = _extract_amount(raw_value)
        if amount:
            return amount
        cleaned = raw_value.replace("￥", "").replace("¥", "").replace("元", "").replace(",", ".").strip()
        try:
            candidate = Decimal(cleaned)
            # NaN would make the ordering comparison raise, and huge exponents
            # make quantize raise; both surface as InvalidOperation.
            if not candidate.is_finite() or candidate <= Decimal("0.00"):
                return ""
            normalized = candidate.quantize(Decimal("0.01"))
        except InvalidOperation:
            return ""
        if normalized <= Decimal("0.00"):
            return ""
        return format(normalized, "f").rstrip("0").rstrip(".")
    if key == "date":
        return _extract_date(raw_value) or _clean_field_value(raw_value)
    if key == "route":
        return _extract_route(raw_value) or _clean_field_value(
            raw_value.replace("->", "-").replace("→", "-").replace("至", "-")
        )
    return _clean_field_value(raw_value)
def _llm_document_field_label(key: str) -> str:
return {
"amount": "金额",
"date": "日期",
"merchant_name": "商户",
"invoice_number": "票据号码",
"invoice_code": "发票代码",
"trip_no": "车次/航班",
"route": "行程",
}.get(key, key)
def _merge_document_fields(
base_fields: tuple[DocumentField, ...],
override_fields: tuple[DocumentField, ...],
) -> tuple[DocumentField, ...]:
merged = {field.key: field for field in base_fields if field.key and field.value}
order = [field.key for field in base_fields if field.key and field.value]
for field in override_fields:
if not field.key or not field.value:
continue
merged[field.key] = field
if field.key not in order:
order.append(field.key)
return tuple(merged[key] for key in order if key in merged)
def _extract_document_fields(text: str) -> list[DocumentField]:
    """Run the rule-based extractors over ``text`` and collect non-empty results.

    Fields appear in a fixed order: amount, date, merchant, invoice number,
    invoice code, trip number, route. Extractors that find nothing are skipped.
    """
    extractors = (
        ("amount", _extract_amount),
        ("date", _extract_date),
        ("merchant_name", _extract_merchant),
        ("invoice_number", lambda source: _extract_pattern(INVOICE_NUMBER_PATTERN, source)),
        ("invoice_code", lambda source: _extract_pattern(INVOICE_CODE_PATTERN, source)),
        ("trip_no", lambda source: _extract_pattern(TRIP_NO_PATTERN, source)),
        ("route", _extract_route),
    )
    collected: list[DocumentField] = []
    for field_key, extractor in extractors:
        extracted = extractor(text)
        if extracted:
            collected.append(
                DocumentField(key=field_key, label=_llm_document_field_label(field_key), value=extracted)
            )
    return collected
def _extract_amount(text: str) -> str:
    """Return the largest positive amount found in ``text`` as a plain decimal string.

    Every match of every pattern in ``AMOUNT_PATTERNS`` is a candidate, with
    ``,`` read as a decimal separator. The winner is rounded to cents and
    trailing zeros are dropped (``"128.50"`` -> ``"128.5"``, ``"32.00"`` -> ``"32"``).
    Returns ``""`` when no positive amount is found.
    """
    best_value: Decimal | None = None
    for pattern in AMOUNT_PATTERNS:
        for match in pattern.finditer(text):
            raw_value = str(match.group(1) or "").replace(",", ".").strip()
            try:
                # Quantize per candidate, inside the try. A digit run longer than
                # the decimal context precision makes quantize raise
                # InvalidOperation; such a candidate is skipped instead of
                # aborting the whole extraction.
                candidate = Decimal(raw_value).quantize(Decimal("0.01"))
            except InvalidOperation:
                continue
            if candidate <= Decimal("0.00"):
                continue
            if best_value is None or candidate > best_value:
                best_value = candidate
    if best_value is None:
        return ""
    return format(best_value, "f").rstrip("0").rstrip(".")
# Maps every separator DATE_PATTERN accepts onto "-" and drops the trailing "日".
_DATE_SEPARATORS = str.maketrans({"年": "-", "月": "-", "日": None, "/": "-", ".": "-"})


def _extract_date(text: str) -> str:
    """Return the first date in ``text`` normalized to ``YYYY-MM-DD``.

    Accepts the forms matched by ``DATE_PATTERN`` (``2024-3-5``, ``2024/03/05``,
    ``2024.3.5``, ``2024年3月5日``). If the match does not split into exactly
    three parts, the raw matched text is returned. Returns ``""`` when no date
    is present.

    The previous ``str.replace("", "-")`` calls inserted a dash between every
    character, so normalization never succeeded.
    """
    match = DATE_PATTERN.search(text)
    if not match:
        return ""
    raw_value = str(match.group(1) or "").strip()
    parts = [part for part in raw_value.translate(_DATE_SEPARATORS).split("-") if part]
    if len(parts) != 3:
        return raw_value
    year, month, day = parts
    return f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}"
def _extract_merchant(text: str) -> str:
    """Return the first non-empty merchant name matched by ``MERCHANT_PATTERNS``, else ``""``.

    Patterns are tried in order, and each match is cleaned with
    ``_clean_field_value``.
    """
    for candidate_pattern in MERCHANT_PATTERNS:
        found = candidate_pattern.search(text)
        cleaned = _clean_field_value(found.group(1)) if found else ""
        if cleaned:
            return cleaned
    return ""
def _extract_route(text: str) -> str:
    """Return the first ``origin-destination`` route found by ``ROUTE_PATTERN``.

    Returns ``""`` when there is no match, when either end cleans to an empty
    string, or when both ends are identical.
    """
    found = ROUTE_PATTERN.search(text)
    if found is None:
        return ""
    origin, destination = (_clean_field_value(part) for part in found.groups())
    if origin and destination and origin != destination:
        return f"{origin}-{destination}"
    return ""
def _extract_pattern(pattern: re.Pattern[str], text: str) -> str:
    """Return capture group 1 of the first match of ``pattern``, cleaned, or ``""``."""
    found = pattern.search(text)
    return _clean_field_value(found.group(1)) if found else ""
def _clean_field_value(value: str) -> str:
return str(value or "").strip().strip(":,。.;")