# Source file: X-Financial/server/src/app/services/document_intelligence.py
# (repository viewer listing: 583 lines, 21 KiB, Python)

from __future__ import annotations
import json
import re
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from app.services.runtime_chat import RuntimeChatService
@dataclass(frozen=True, slots=True)
class DocumentField:
key: str
label: str
value: str
@dataclass(frozen=True, slots=True)
class DocumentInsight:
document_type: str
document_type_label: str
scene_code: str
scene_label: str
expense_type: str
fields: tuple[DocumentField, ...] = ()
classification_source: str = "rule"
classification_confidence: float = 0.0
evidence: tuple[str, ...] = ()
warnings: tuple[str, ...] = ()
@dataclass(frozen=True, slots=True)
class DocumentRule:
document_type: str
document_type_label: str
scene_code: str
scene_label: str
expense_type: str
keywords: tuple[str, ...]
score_bias: float = 0.0
@dataclass(frozen=True, slots=True)
class RuleMatch:
rule: DocumentRule | None
confidence: float
evidence: tuple[str, ...]
score: float
class LlmDocumentClassification(BaseModel):
    """Schema used to validate the classification JSON returned by the model."""

    document_type: str = "other"
    scene_code: str = "other"
    scene_label: str = "其他票据"
    expense_type: str = "other"
    # Constrained to [0.0, 1.0]; values outside that range fail validation.
    confidence: float = Field(0.0, ge=0.0, le=1.0)
    evidence: list[str] = Field(default_factory=list)
# Fallback rule: the generic "other" classification used when no keyword rule
# matches or when a document type has no entry in the rule lookup table.
DEFAULT_RULE = DocumentRule(
    document_type="other",
    document_type_label="其他单据",
    scene_code="other",
    scene_label="其他票据",
    expense_type="other",
    keywords=(),
    score_bias=0.0,
)
# Keyword rules used for the rule-based classification pass.
# `_match_document_rule` scores every rule by the keywords it finds in the
# whitespace-free, lower-cased document text plus the rule's `score_bias`; the
# highest score wins and, on a tie, the rule listed first is kept.
# The generic invoice/receipt rules at the end carry a negative bias, which
# lowers their score relative to the more specific rules above them.
DOCUMENT_RULES: tuple[DocumentRule, ...] = (
    DocumentRule(
        document_type="flight_itinerary",
        document_type_label="机票/航班行程单",
        scene_code="travel",
        scene_label="差旅票据",
        expense_type="travel",
        keywords=("电子行程单", "航班号", "航班", "机票", "登机", "航空", "客票"),
        score_bias=0.34,
    ),
    DocumentRule(
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        expense_type="travel",
        keywords=("高铁", "火车", "动车", "铁路", "车次", "检票", "二等座", "一等座"),
        score_bias=0.32,
    ),
    DocumentRule(
        document_type="hotel_invoice",
        document_type_label="酒店住宿票据",
        scene_code="hotel",
        scene_label="住宿票据",
        expense_type="hotel",
        keywords=("住宿", "房费", "客房", "入住", "离店", "酒店", "宾馆", "间夜"),
        score_bias=0.16,
    ),
    DocumentRule(
        document_type="taxi_receipt",
        document_type_label="出租车/网约车票据",
        scene_code="transport",
        scene_label="交通票据",
        expense_type="transport",
        keywords=("滴滴出行", "滴滴", "网约车", "出租车", "打车", "快车", "专车", "订单号", "上车", "下车", "起点", "终点", "里程", "司机"),
        score_bias=0.38,
    ),
    DocumentRule(
        document_type="parking_toll_receipt",
        document_type_label="停车/通行费票据",
        scene_code="transport",
        scene_label="交通票据",
        expense_type="transport",
        keywords=("停车费", "通行费", "过路费", "收费站", "停车场", "停车"),
        score_bias=0.28,
    ),
    DocumentRule(
        document_type="meal_receipt",
        document_type_label="餐饮票据",
        scene_code="meal",
        scene_label="餐饮票据",
        expense_type="meal",
        keywords=("餐饮", "餐费", "用餐", "饭店", "酒楼", "餐厅", "食品", "外卖", "咖啡"),
        score_bias=0.14,
    ),
    DocumentRule(
        document_type="office_invoice",
        document_type_label="办公用品票据",
        scene_code="office",
        scene_label="办公用品票据",
        expense_type="office",
        keywords=("办公用品", "文具", "耗材", "打印纸", "墨盒", "硒鼓", "键盘", "鼠标"),
        score_bias=0.14,
    ),
    DocumentRule(
        document_type="meeting_invoice",
        document_type_label="会议/会务票据",
        scene_code="meeting",
        scene_label="会务票据",
        expense_type="meeting",
        keywords=("会议", "会务", "会展", "论坛", "会议室", "会场"),
        score_bias=0.12,
    ),
    DocumentRule(
        document_type="training_invoice",
        document_type_label="培训票据",
        scene_code="training",
        scene_label="培训票据",
        expense_type="training",
        keywords=("培训", "课程", "讲师", "教材", "学费", "认证"),
        score_bias=0.12,
    ),
    DocumentRule(
        document_type="vat_invoice",
        document_type_label="增值税发票",
        scene_code="other",
        scene_label="通用发票",
        expense_type="other",
        keywords=("发票代码", "发票号码", "价税合计", "增值税", "电子发票"),
        score_bias=-0.08,
    ),
    DocumentRule(
        document_type="receipt",
        document_type_label="一般收据/凭证",
        scene_code="other",
        scene_label="其他票据",
        expense_type="other",
        keywords=("收据", "凭证", "票据"),
        score_bias=-0.18,
    ),
)
# Lookup table from a document_type string to the rule that defines it.
DOCUMENT_TYPE_RULE_MAP = dict((rule.document_type, rule) for rule in DOCUMENT_RULES)
# Every document type the classifier may emit: all rule types plus the "other" fallback.
SUPPORTED_DOCUMENT_TYPES = (*DOCUMENT_TYPE_RULE_MAP, "other")
# Regular expressions used to pull structured fields out of OCR text.
# Label/value separators accept both the ASCII colon and the full-width colon
# (":"), and amounts accept both yen signs ("¥" and "¥"), because Chinese
# documents commonly use the full-width forms.

# Amount patterns, tried in order: a labelled amount first, then any number
# directly followed by "元". Group 1 is the numeric text.
AMOUNT_PATTERNS = (
    re.compile(r"(?:价税合计|合计|金额|总额|票价|支付金额|实付金额|实收金额)[::\s¥¥]*([0-9]+(?:[.,][0-9]{1,2})?)"),
    re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
# A 19xx/20xx date written with "-", "/", "." or 年/月/日 separators.
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
# Labelled invoice / ticket / order number (6-24 alphanumerics or hyphens).
INVOICE_NUMBER_PATTERN = re.compile(r"(?:发票号码|票号|单号|订单号)[::\s]*([A-Za-z0-9-]{6,24})")
# Labelled invoice code.
INVOICE_CODE_PATTERN = re.compile(r"(?:发票代码)[::\s]*([A-Za-z0-9-]{6,24})")
# Labelled train number or flight number.
TRIP_NO_PATTERN = re.compile(r"(?:车次|航班(?:号)?)[::\s]*([A-Za-z0-9]{2,12})")
# Two runs of Chinese characters joined by a route separator (至, →, ->, -).
ROUTE_PATTERN = re.compile(r"([\u4e00-\u9fa5]{2,12})\s*(?:至|→|->|-)\s*([\u4e00-\u9fa5]{2,12})")
# Merchant name: a labelled seller/merchant/payee first, then any name that
# ends in a recognisable business suffix.
MERCHANT_PATTERNS = (
    re.compile(r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)[::\s]*([A-Za-z0-9\u4e00-\u9fa5()·&\\-]{2,40})"),
    re.compile(r"([A-Za-z0-9\u4e00-\u9fa5()·&\\-]{2,40}(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"),
)
class DocumentIntelligenceService:
    """Classify expense documents with keyword rules plus an optional model review.

    The keyword rules always run. When a database session is supplied, a
    ``RuntimeChatService`` is created and asked to review the rule result;
    without a session only the rule result is returned.
    """

    def __init__(self, db: Session | None = None) -> None:
        """Create the service; the model review is enabled only when ``db`` is given."""
        self.runtime_chat_service = RuntimeChatService(db) if db is not None else None

    def build_document_insight(
        self,
        *,
        filename: str = "",
        summary: str = "",
        text: str = "",
        preview_data_url: str = "",
    ) -> DocumentInsight:
        """Classify a document and extract its key fields.

        Args:
            filename: Name of the uploaded file.
            summary: Short description of the document, if any.
            text: OCR / extracted text of the document.
            preview_data_url: Image URL forwarded to the vision model when present.

        Returns:
            The rule-based insight, or the model-adjusted insight when a model
            answer is available and passes the merge thresholds.
        """
        # Normalise every input once so all later steps see the same values.
        clean_filename = str(filename or "").strip()
        clean_summary = str(summary or "").strip()
        clean_text = str(text or "").strip()
        clean_preview = str(preview_data_url or "").strip()

        raw_text = " ".join([clean_filename, clean_summary, clean_text]).strip()
        # Keyword matching runs on whitespace-free, lower-cased text.
        compact = re.sub(r"\s+", "", raw_text).lower()
        rule_match = _match_document_rule(compact)
        base_rule = rule_match.rule or DEFAULT_RULE
        fields = tuple(_extract_document_fields(raw_text))
        rule_insight = DocumentInsight(
            document_type=base_rule.document_type,
            document_type_label=base_rule.document_type_label,
            scene_code=base_rule.scene_code,
            scene_label=base_rule.scene_label,
            expense_type=base_rule.expense_type,
            fields=fields,
            classification_source="rule",
            classification_confidence=rule_match.confidence,
            evidence=rule_match.evidence,
        )
        llm_result = self._classify_with_model(
            filename=clean_filename,
            summary=clean_summary,
            text=clean_text,
            preview_data_url=clean_preview,
            rule_insight=rule_insight,
            fields=fields,
        )
        if llm_result is None:
            return rule_insight
        return self._merge_rule_and_model(
            rule_insight=rule_insight,
            llm_result=llm_result,
            fields=fields,
            # Use the stripped URL so a whitespace-only value does not count as a preview.
            has_preview=bool(clean_preview),
        )

    def _classify_with_model(
        self,
        *,
        filename: str,
        summary: str,
        text: str,
        preview_data_url: str,
        rule_insight: DocumentInsight,
        fields: tuple[DocumentField, ...],
    ) -> tuple[str, LlmDocumentClassification] | None:
        """Ask the chat service to review the rule-based classification.

        Returns:
            ``(source, classification)`` where ``source`` is ``"llm_vision"``
            when the answer came from the image-capable call and ``"llm_text"``
            when it came from the text-only call; ``None`` when no chat service
            is configured, there is no text or summary to review, or no
            response could be parsed.
        """
        if self.runtime_chat_service is None:
            return None
        trimmed_text = text.strip()
        if not trimmed_text and not summary.strip():
            return None

        # Facts handed to the model; summary and OCR text are truncated.
        facts = {
            "filename": filename,
            "summary": summary[:300],
            "ocr_text_excerpt": trimmed_text[:2000],
            "rule_candidate": {
                "document_type": rule_insight.document_type,
                "document_type_label": rule_insight.document_type_label,
                "scene_code": rule_insight.scene_code,
                "scene_label": rule_insight.scene_label,
                "expense_type": rule_insight.expense_type,
                "confidence": round(rule_insight.classification_confidence, 2),
                "evidence": list(rule_insight.evidence),
            },
            "extracted_fields": [
                {"key": field.key, "label": field.label, "value": field.value}
                for field in fields
            ],
            "allowed_document_types": list(SUPPORTED_DOCUMENT_TYPES),
        }
        system_prompt = (
            "你是企业报销票据识别复核器。"
            "你的任务不是 OCR而是在已有 OCR 文本和票据预览基础上判断票据类型。"
            "只输出 JSON 对象,不要输出 Markdown、解释或代码块。"
            "document_type 只能是:"
            f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}"
            "如果证据不足,返回 other。"
            "严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。"
            "如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。"
            "例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。"
            "输出字段document_type, scene_code, scene_label, expense_type, confidence, evidence。"
        )
        user_prompt = (
            "请根据以下票据事实给出最终分类 JSON\n"
            f"{json.dumps(facts, ensure_ascii=False, indent=2)}\n\n"
            "示例输出:\n"
            "{\n"
            ' "document_type": "taxi_receipt",\n'
            ' "scene_code": "transport",\n'
            ' "scene_label": "交通票据",\n'
            ' "expense_type": "transport",\n'
            ' "confidence": 0.86,\n'
            ' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"]\n'
            "}"
        )

        # First try the image-capable slot when a preview is available.
        if preview_data_url:
            response_text = self.runtime_chat_service.complete(
                [
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": user_prompt},
                            {"type": "image_url", "image_url": {"url": preview_data_url}},
                        ],
                    },
                ],
                slot_priority=("vlm",),
                max_tokens=320,
                temperature=0.0,
            )
            parsed = self._parse_llm_payload(response_text)
            if parsed is not None:
                return "llm_vision", parsed

        # Text-only fallback (also the only path when there is no preview).
        response_text = self.runtime_chat_service.complete(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            slot_priority=("main", "backup"),
            max_tokens=320,
            temperature=0.0,
        )
        parsed = self._parse_llm_payload(response_text)
        if parsed is not None:
            return "llm_text", parsed
        return None

    @staticmethod
    def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None:
        """Validate and normalise the model's JSON answer.

        Returns ``None`` when no JSON object can be found or it fails schema
        validation. Unsupported document types are coerced to ``"other"``,
        blank scene/expense values fall back to the matching rule's values,
        and at most four non-empty evidence strings are kept.
        """
        payload_json = _extract_json_payload(response_text)
        if payload_json is None:
            return None
        try:
            parsed = LlmDocumentClassification.model_validate(payload_json)
        except ValidationError:
            return None

        normalized_type = str(parsed.document_type or "other").strip().lower() or "other"
        if normalized_type not in SUPPORTED_DOCUMENT_TYPES:
            normalized_type = "other"
        base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE)
        evidence = [
            str(item or "").strip()
            for item in parsed.evidence
            if str(item or "").strip()
        ][:4]
        return LlmDocumentClassification(
            document_type=normalized_type,
            scene_code=str(parsed.scene_code or base_rule.scene_code).strip() or base_rule.scene_code,
            scene_label=str(parsed.scene_label or base_rule.scene_label).strip() or base_rule.scene_label,
            expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type,
            confidence=float(parsed.confidence),
            evidence=evidence,
        )

    @staticmethod
    def _merge_rule_and_model(
        *,
        rule_insight: DocumentInsight,
        llm_result: tuple[str, LlmDocumentClassification],
        fields: tuple[DocumentField, ...],
        has_preview: bool,
    ) -> DocumentInsight:
        """Decide whether the model answer replaces the rule-based insight.

        The rule insight is kept when the model confidence is below 0.55, when
        the model answers ``"other"`` for a document the rules classified, or
        when a conflicting model answer does not reach the override threshold.
        The relaxed threshold (0.60) applies only when the answer actually came
        from the vision call; a text-only answer must reach
        ``max(0.76, rule confidence + 0.12)`` even if a preview was supplied.
        """
        source, parsed = llm_result
        if parsed.confidence < 0.55:
            return rule_insight

        if parsed.document_type == rule_insight.document_type:
            should_override = True
        elif parsed.document_type == "other":
            # The model is unsure; never downgrade a specific rule match.
            should_override = False
        elif rule_insight.document_type == "other":
            should_override = True
        else:
            # Only an answer produced with the image gets the lower bar; the
            # text fallback after a failed vision call has not seen the image.
            image_reviewed = has_preview and source == "llm_vision"
            threshold = 0.60 if image_reviewed else max(0.76, rule_insight.classification_confidence + 0.12)
            should_override = parsed.confidence >= threshold
        if not should_override:
            return rule_insight

        rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
        warnings = list(rule_insight.warnings)
        if parsed.document_type != rule_insight.document_type:
            warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
        return DocumentInsight(
            document_type=rule.document_type,
            document_type_label=rule.document_type_label,
            # Model values equal to the schema defaults are treated as "not provided".
            scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
            scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
            expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
            fields=fields,
            classification_source=source,
            classification_confidence=max(parsed.confidence, rule_insight.classification_confidence),
            evidence=tuple(parsed.evidence or rule_insight.evidence),
            warnings=tuple(warnings),
        )
def build_document_insight(
    *,
    filename: str = "",
    summary: str = "",
    text: str = "",
    preview_data_url: str = "",
) -> DocumentInsight:
    """Build a document insight without a database session.

    Convenience wrapper that forwards its arguments to a
    ``DocumentIntelligenceService`` created with no session.
    """
    service = DocumentIntelligenceService()
    arguments = {
        "filename": filename,
        "summary": summary,
        "text": text,
        "preview_data_url": preview_data_url,
    }
    return service.build_document_insight(**arguments)
def _match_document_rule(compact_text: str) -> RuleMatch:
    """Pick the keyword rule that best matches ``compact_text``.

    Each rule scores its bias, plus 0.92 per matched keyword, plus 0.08 per
    keyword character (capped at 6 characters per keyword). The highest score
    wins; on a tie the rule listed first is kept. Confidence is derived from
    the score and never exceeds 0.94. At most four keywords are reported as
    evidence. When nothing matches, the result has ``rule=None``.
    """
    winner = DEFAULT_RULE
    winner_hits: tuple[str, ...] = ()
    top_score = 0.0
    for candidate in DOCUMENT_RULES:
        hits = tuple(filter(lambda word: word.lower() in compact_text, candidate.keywords))
        if hits:
            hit_weight = len(hits) * 0.92
            length_bonus = sum(min(len(word), 6) * 0.08 for word in hits)
            total = float(candidate.score_bias) + hit_weight + length_bonus
            if total > top_score:
                winner, winner_hits, top_score = candidate, hits, total

    if top_score <= 0:
        return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0)

    capped_score = min(top_score, 4.8)
    confidence = min(0.94, 0.30 + capped_score * 0.12)
    return RuleMatch(
        rule=winner,
        confidence=round(confidence, 2),
        evidence=winner_hits[:4],
        score=top_score,
    )
def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
if not response_text:
return None
cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE).strip()
if not cleaned:
return None
fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL)
candidates = [fenced_match.group(1)] if fenced_match else []
candidates.append(cleaned)
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
candidates.append(cleaned[start : end + 1])
for candidate in candidates:
try:
parsed = json.loads(candidate)
except json.JSONDecodeError:
continue
if isinstance(parsed, dict):
return parsed
return None
def _extract_document_fields(text: str) -> list[DocumentField]:
    """Run every field extractor over ``text`` and keep the non-empty results.

    Fields are returned in a fixed order: amount, date, merchant, invoice
    number, invoice code, trip number, route.
    """
    extracted = (
        ("amount", "金额", _extract_amount(text)),
        ("date", "日期", _extract_date(text)),
        ("merchant_name", "商户", _extract_merchant(text)),
        ("invoice_number", "票据号码", _extract_pattern(INVOICE_NUMBER_PATTERN, text)),
        ("invoice_code", "发票代码", _extract_pattern(INVOICE_CODE_PATTERN, text)),
        ("trip_no", "车次/航班", _extract_pattern(TRIP_NO_PATTERN, text)),
        ("route", "行程", _extract_route(text)),
    )
    return [
        DocumentField(key=key, label=label, value=value)
        for key, label, value in extracted
        if value
    ]
def _extract_amount(text: str) -> str:
    """Return the largest positive amount found in ``text`` as plain text.

    The patterns in ``AMOUNT_PATTERNS`` are tried in order; the first pattern
    that yields at least one usable amount decides the result. A comma in the
    matched number is treated as a decimal point. The value is rounded to two
    decimals and trailing zeros are dropped ("100.00" becomes "100").

    Returns:
        The amount text, or an empty string when no usable amount is found.
    """
    cent = Decimal("0.01")
    best_value: Decimal | None = None
    for pattern in AMOUNT_PATTERNS:
        for match in pattern.finditer(text):
            raw_value = str(match.group(1) or "").replace(",", ".").strip()
            try:
                # Quantize inside the guard: very long digit runs (e.g. a
                # serial number next to "元") exceed the decimal context
                # precision and raise InvalidOperation; skip them.
                candidate = Decimal(raw_value).quantize(cent)
            except InvalidOperation:
                continue
            if candidate <= 0:
                continue
            if best_value is None or candidate > best_value:
                best_value = candidate
        if best_value is not None:
            break
    if best_value is None:
        return ""
    return format(best_value, "f").rstrip("0").rstrip(".")
def _extract_date(text: str) -> str:
    """Return the first date found in ``text`` as ``YYYY-MM-DD``.

    Dates written with "-", "/", "." or the 年/月/日 markers are all
    normalised to zero-padded ISO form.

    Returns:
        The normalised date, the raw matched text if it does not split into
        exactly three parts, or an empty string when no date is found.
    """
    match = DATE_PATTERN.search(text)
    if not match:
        return ""
    raw_value = str(match.group(1) or "").strip()
    # Split on every separator the date pattern accepts. The previous
    # empty-string replacements put a dash between every character, so the
    # three-part check below never passed.
    parts = [part for part in re.split(r"[-/.年月日]", raw_value) if part]
    if len(parts) != 3:
        return raw_value
    year, month, day = parts
    return f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}"
def _extract_merchant(text: str) -> str:
    """Return the first non-empty merchant name matched by ``MERCHANT_PATTERNS``.

    Patterns are tried in order; a match whose cleaned value is empty is
    skipped. Returns an empty string when nothing usable is found.
    """
    hits = (pattern.search(text) for pattern in MERCHANT_PATTERNS)
    names = (_clean_field_value(hit.group(1)) for hit in hits if hit)
    return next((name for name in names if name), "")
def _extract_route(text: str) -> str:
    """Return the first "origin-destination" pair found in ``text``.

    Returns an empty string when no route is found, when either end is empty
    after cleaning, or when both ends are the same.
    """
    found = ROUTE_PATTERN.search(text)
    if found is None:
        return ""
    origin = _clean_field_value(found.group(1))
    destination = _clean_field_value(found.group(2))
    if origin and destination and origin != destination:
        return f"{origin}-{destination}"
    return ""
def _extract_pattern(pattern: re.Pattern[str], text: str) -> str:
match = pattern.search(text)
if not match:
return ""
return _clean_field_value(match.group(1))
def _clean_field_value(value: str) -> str:
return str(value or "").strip().strip(":,。.;")