from __future__ import annotations import json import re from dataclasses import dataclass from decimal import Decimal, InvalidOperation from typing import Any from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session @dataclass(frozen=True, slots=True) class DocumentField: key: str label: str value: str @dataclass(frozen=True, slots=True) class DocumentInsight: document_type: str document_type_label: str scene_code: str scene_label: str expense_type: str fields: tuple[DocumentField, ...] = () classification_source: str = "rule" classification_confidence: float = 0.0 evidence: tuple[str, ...] = () warnings: tuple[str, ...] = () @dataclass(frozen=True, slots=True) class DocumentRule: document_type: str document_type_label: str scene_code: str scene_label: str expense_type: str keywords: tuple[str, ...] score_bias: float = 0.0 @dataclass(frozen=True, slots=True) class RuleMatch: rule: DocumentRule | None confidence: float evidence: tuple[str, ...] score: float class LlmDocumentClassification(BaseModel): document_type: str = Field(default="other") scene_code: str = Field(default="other") scene_label: str = Field(default="其他票据") expense_type: str = Field(default="other") confidence: float = Field(default=0.0, ge=0.0, le=1.0) evidence: list[str] = Field(default_factory=list) fields: list[DocumentField] = Field(default_factory=list) DEFAULT_RULE = DocumentRule( document_type="other", document_type_label="其他单据", scene_code="other", scene_label="其他票据", expense_type="other", keywords=(), score_bias=0.0, ) DOCUMENT_RULES: tuple[DocumentRule, ...] = ( DocumentRule( document_type="flight_itinerary", document_type_label="机票/航班行程单", scene_code="travel", scene_label="差旅票据", expense_type="travel", keywords=("电子行程单", "航班号", "航班", "机票", "登机", "航空", "客票"), score_bias=0.34, ), DocumentRule( document_type="train_ticket", document_type_label="火车/高铁票", scene_code="travel", scene_label="差旅票据", expense_type="travel", keywords=("铁路电子客票", "电子客票", "高铁", "火车", "动车", "铁路", "车次", "检票", "二等座", "一等座", "票价"), score_bias=0.32, ), DocumentRule( document_type="hotel_invoice", document_type_label="酒店住宿票据", scene_code="hotel", scene_label="住宿票据", expense_type="hotel", keywords=("住宿", "房费", "客房", "入住", "离店", "酒店", "宾馆", "间夜"), score_bias=0.16, ), DocumentRule( document_type="taxi_receipt", document_type_label="出租车/网约车票据", scene_code="transport", scene_label="交通票据", expense_type="transport", keywords=("滴滴出行", "滴滴", "网约车", "出租车", "打车", "乘车", "用车", "叫车", "车费", "车资", "的士", "快车", "专车", "订单号", "上车", "下车", "起点", "终点", "里程", "司机"), score_bias=0.38, ), DocumentRule( document_type="parking_toll_receipt", document_type_label="停车/通行费票据", scene_code="transport", scene_label="交通票据", expense_type="transport", keywords=("停车费", "通行费", "过路费", "收费站", "停车场", "停车"), score_bias=0.28, ), DocumentRule( document_type="meal_receipt", document_type_label="餐饮票据", scene_code="meal", scene_label="餐饮票据", expense_type="meal", keywords=("餐饮", "餐费", "用餐", "饭店", "酒楼", "餐厅", "食品", "外卖", "咖啡"), score_bias=0.14, ), DocumentRule( document_type="office_invoice", document_type_label="办公用品票据", scene_code="office", scene_label="办公用品票据", expense_type="office", keywords=("办公用品", "文具", "耗材", "打印纸", "墨盒", "硒鼓", "键盘", "鼠标"), score_bias=0.14, ), DocumentRule( document_type="meeting_invoice", document_type_label="会议/会务票据", scene_code="meeting", scene_label="会务票据", expense_type="meeting", keywords=("会议", "会务", "会展", "论坛", "会议室", "会场"), score_bias=0.12, ), DocumentRule( document_type="training_invoice", document_type_label="培训票据", scene_code="training", scene_label="培训票据", expense_type="training", keywords=("培训", "课程", "讲师", "教材", "学费", "认证"), score_bias=0.12, ), DocumentRule( document_type="vat_invoice", document_type_label="增值税发票", scene_code="other", scene_label="通用发票", expense_type="other", keywords=("发票代码", "发票号码", "价税合计", "增值税", "电子发票"), score_bias=-0.08, ), DocumentRule( document_type="receipt", document_type_label="一般收据/凭证", scene_code="other", scene_label="其他票据", expense_type="other", keywords=("收据", "凭证", "票据"), score_bias=-0.18, ), ) DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES} SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",) AMOUNT_PATTERNS = ( re.compile( r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)" r"[::\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)" ), re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"), re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"), ) DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)") INVOICE_NUMBER_PATTERN = re.compile(r"(?:发票号码|票号|单号|订单号)[::\s]*([A-Za-z0-9-]{6,24})") INVOICE_CODE_PATTERN = re.compile(r"(?:发票代码)[::\s]*([A-Za-z0-9-]{6,24})") TRIP_NO_PATTERN = re.compile(r"(?:车次|航班(?:号)?)[::\s]*([A-Za-z0-9]{2,12})") ROUTE_PATTERN = re.compile(r"([\u4e00-\u9fa5]{2,12})\s*(?:至|→|->|-)\s*([\u4e00-\u9fa5]{2,12})") MERCHANT_PATTERNS = ( re.compile(r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)[::\s]*([A-Za-z0-9\u4e00-\u9fa5()()·&\\-]{2,40})"), re.compile(r"([A-Za-z0-9\u4e00-\u9fa5()()·&\\-]{2,40}(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"), ) class DocumentIntelligenceService: def __init__(self, db: Session | None = None) -> None: self.db = db def build_document_insight( self, *, filename: str = "", summary: str = "", text: str = "", preview_data_url: str = "", ) -> DocumentInsight: raw_text = " ".join( [str(filename or "").strip(), str(summary or "").strip(), str(text or "").strip()] ).strip() compact = re.sub(r"\s+", "", raw_text).lower() rule_match = _match_document_rule(compact) base_rule = rule_match.rule or DEFAULT_RULE fields = tuple(_extract_document_fields(raw_text)) rule_insight = DocumentInsight( document_type=base_rule.document_type, document_type_label=base_rule.document_type_label, scene_code=base_rule.scene_code, scene_label=base_rule.scene_label, expense_type=base_rule.expense_type, fields=fields, classification_source="rule", classification_confidence=rule_match.confidence, evidence=rule_match.evidence, ) llm_result = self._classify_with_model( filename=str(filename or "").strip(), summary=str(summary or "").strip(), text=str(text or "").strip(), preview_data_url=str(preview_data_url or "").strip(), rule_insight=rule_insight, fields=fields, ) if llm_result is None: return rule_insight return self._merge_rule_and_model( rule_insight=rule_insight, llm_result=llm_result, fields=fields, has_preview=bool(preview_data_url), ) def _classify_with_model( self, *, filename: str, summary: str, text: str, preview_data_url: str, rule_insight: DocumentInsight, fields: tuple[DocumentField, ...], ) -> tuple[str, LlmDocumentClassification] | None: return None @staticmethod def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None: payload_json = _extract_json_payload(response_text) if payload_json is None: return None try: parsed = LlmDocumentClassification.model_validate(payload_json) except ValidationError: return None normalized_type = str(parsed.document_type or "other").strip().lower() or "other" if normalized_type not in SUPPORTED_DOCUMENT_TYPES: normalized_type = "other" base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE) evidence = [ str(item or "").strip() for item in parsed.evidence if str(item or "").strip() ][:4] normalized_fields = _normalize_llm_document_fields(parsed.fields) return LlmDocumentClassification( document_type=normalized_type, scene_code=str(parsed.scene_code or base_rule.scene_code).strip() or base_rule.scene_code, scene_label=str(parsed.scene_label or base_rule.scene_label).strip() or base_rule.scene_label, expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type, confidence=float(parsed.confidence), evidence=evidence, fields=normalized_fields, ) @staticmethod def _merge_rule_and_model( *, rule_insight: DocumentInsight, llm_result: tuple[str, LlmDocumentClassification], fields: tuple[DocumentField, ...], has_preview: bool, ) -> DocumentInsight: source, parsed = llm_result warnings = list(rule_insight.warnings) merged_fields = rule_insight.fields if parsed.fields and (has_preview or parsed.confidence >= 0.55): merged_fields = _merge_document_fields(rule_insight.fields, tuple(parsed.fields)) if merged_fields != rule_insight.fields: warnings.append("票据关键信息已结合大模型复核结果修正,建议人工再核对原图。") if parsed.confidence < 0.55: if merged_fields == rule_insight.fields: return rule_insight return DocumentInsight( document_type=rule_insight.document_type, document_type_label=rule_insight.document_type_label, scene_code=rule_insight.scene_code, scene_label=rule_insight.scene_label, expense_type=rule_insight.expense_type, fields=merged_fields, classification_source=rule_insight.classification_source, classification_confidence=rule_insight.classification_confidence, evidence=rule_insight.evidence, warnings=tuple(warnings), ) should_override = False if parsed.document_type == rule_insight.document_type: should_override = True elif rule_insight.document_type == "other" and parsed.document_type != "other": should_override = True elif parsed.document_type != "other": threshold = 0.60 if has_preview else max(0.76, rule_insight.classification_confidence + 0.12) should_override = parsed.confidence >= threshold if not should_override: if merged_fields == rule_insight.fields: return rule_insight return DocumentInsight( document_type=rule_insight.document_type, document_type_label=rule_insight.document_type_label, scene_code=rule_insight.scene_code, scene_label=rule_insight.scene_label, expense_type=rule_insight.expense_type, fields=merged_fields, classification_source=rule_insight.classification_source, classification_confidence=rule_insight.classification_confidence, evidence=rule_insight.evidence, warnings=tuple(warnings), ) rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE) if parsed.document_type != rule_insight.document_type: warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。") return DocumentInsight( document_type=rule.document_type, document_type_label=rule.document_type_label, scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code, scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label, expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type, fields=merged_fields, classification_source=source, classification_confidence=max(parsed.confidence, rule_insight.classification_confidence), evidence=tuple(parsed.evidence or rule_insight.evidence), warnings=tuple(warnings), ) def build_document_insight( *, filename: str = "", summary: str = "", text: str = "", preview_data_url: str = "", ) -> DocumentInsight: return DocumentIntelligenceService().build_document_insight( filename=filename, summary=summary, text=text, preview_data_url=preview_data_url, ) def _match_document_rule(compact_text: str) -> RuleMatch: best_rule = DEFAULT_RULE best_evidence: tuple[str, ...] = () best_score = 0.0 for rule in DOCUMENT_RULES: matched = tuple(keyword for keyword in rule.keywords if keyword.lower() in compact_text) if not matched: continue score = float(rule.score_bias) + len(matched) * 0.92 + sum(min(len(keyword), 6) * 0.08 for keyword in matched) if score > best_score: best_rule = rule best_evidence = matched best_score = score if best_score <= 0: return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0) confidence = min(0.94, 0.30 + min(best_score, 4.8) * 0.12) return RuleMatch( rule=best_rule, confidence=round(confidence, 2), evidence=best_evidence[:4], score=best_score, ) def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None: if not response_text: return None cleaned = re.sub(r".*?", "", response_text, flags=re.DOTALL | re.IGNORECASE).strip() if not cleaned: return None fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL) candidates = [fenced_match.group(1)] if fenced_match else [] candidates.append(cleaned) start = cleaned.find("{") end = cleaned.rfind("}") if start != -1 and end != -1 and end > start: candidates.append(cleaned[start : end + 1]) for candidate in candidates: try: parsed = json.loads(candidate) except json.JSONDecodeError: continue if isinstance(parsed, dict): return parsed return None def _normalize_llm_document_fields(raw_fields: list[DocumentField] | list[dict[str, Any]]) -> list[DocumentField]: normalized: list[DocumentField] = [] seen_keys: set[str] = set() for field in raw_fields: raw_key = str(getattr(field, "key", "") if isinstance(field, DocumentField) else field.get("key") or "").strip() raw_label = str(getattr(field, "label", "") if isinstance(field, DocumentField) else field.get("label") or "").strip() raw_value = str(getattr(field, "value", "") if isinstance(field, DocumentField) else field.get("value") or "").strip() key = _normalize_llm_document_field_key(raw_key, raw_label) if not key or key in seen_keys: continue value = _normalize_llm_document_field_value(key, raw_value) if not value: continue seen_keys.add(key) normalized.append( DocumentField( key=key, label=_llm_document_field_label(key), value=value, ) ) return normalized def _normalize_llm_document_field_key(key: str, label: str) -> str: compact_key = str(key or "").strip().lower() compact_label = str(label or "").replace(" ", "").lower() if compact_key in {"amount", "total_amount", "payment_amount", "paid_amount"} or any( token in compact_label for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额") ): return "amount" if compact_key in {"date", "time", "issued_at", "invoice_date"} or any( token in compact_label for token in ("日期", "时间", "开票日期", "发生时间") ): return "date" if compact_key in {"merchant_name", "merchant", "seller_name", "vendor_name"} or any( token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方") ): return "merchant_name" if compact_key in {"invoice_number", "ticket_number", "order_no", "order_number"} or any( token in compact_label for token in ("票据号码", "发票号码", "票号", "单号", "订单号") ): return "invoice_number" if compact_key in {"invoice_code"} or "发票代码" in compact_label: return "invoice_code" if compact_key in {"trip_no", "flight_no", "train_no"} or any( token in compact_label for token in ("车次", "航班") ): return "trip_no" if compact_key in {"route", "trip_route"} or any(token in compact_label for token in ("行程", "路线")): return "route" return "" def _normalize_llm_document_field_value(key: str, value: str) -> str: raw_value = str(value or "").strip() if not raw_value: return "" if key == "amount": amount = _extract_amount(raw_value) if amount: return amount cleaned = raw_value.replace("¥", "").replace("¥", "").replace("元", "").replace(",", ".").strip() try: candidate = Decimal(cleaned) except InvalidOperation: return "" if candidate <= Decimal("0.00"): return "" text_value = format(candidate.quantize(Decimal("0.01")), "f").rstrip("0").rstrip(".") return f"{text_value}元" if key == "date": return _extract_date(raw_value) or _clean_field_value(raw_value) if key == "route": return _extract_route(raw_value) or _clean_field_value( raw_value.replace("→", "-").replace("至", "-").replace("->", "-") ) return _clean_field_value(raw_value) def _llm_document_field_label(key: str) -> str: return { "amount": "金额", "date": "日期", "merchant_name": "商户", "invoice_number": "票据号码", "invoice_code": "发票代码", "trip_no": "车次/航班", "route": "行程", }.get(key, key) def _merge_document_fields( base_fields: tuple[DocumentField, ...], override_fields: tuple[DocumentField, ...], ) -> tuple[DocumentField, ...]: merged = {field.key: field for field in base_fields if field.key and field.value} order = [field.key for field in base_fields if field.key and field.value] for field in override_fields: if not field.key or not field.value: continue merged[field.key] = field if field.key not in order: order.append(field.key) return tuple(merged[key] for key in order if key in merged) def _extract_document_fields(text: str) -> list[DocumentField]: fields: list[DocumentField] = [] amount = _extract_amount(text) if amount: fields.append(DocumentField(key="amount", label="金额", value=amount)) date_value = _extract_date(text) if date_value: fields.append(DocumentField(key="date", label="日期", value=date_value)) merchant = _extract_merchant(text) if merchant: fields.append(DocumentField(key="merchant_name", label="商户", value=merchant)) invoice_number = _extract_pattern(INVOICE_NUMBER_PATTERN, text) if invoice_number: fields.append(DocumentField(key="invoice_number", label="票据号码", value=invoice_number)) invoice_code = _extract_pattern(INVOICE_CODE_PATTERN, text) if invoice_code: fields.append(DocumentField(key="invoice_code", label="发票代码", value=invoice_code)) trip_no = _extract_pattern(TRIP_NO_PATTERN, text) if trip_no: fields.append(DocumentField(key="trip_no", label="车次/航班", value=trip_no)) route = _extract_route(text) if route: fields.append(DocumentField(key="route", label="行程", value=route)) return fields def _extract_amount(text: str) -> str: best_value: Decimal | None = None for pattern in AMOUNT_PATTERNS: for match in pattern.finditer(text): raw_value = str(match.group(1) or "").replace(",", ".").strip() try: candidate = Decimal(raw_value) except InvalidOperation: continue if candidate <= Decimal("0.00"): continue if best_value is None or candidate > best_value: best_value = candidate if best_value is None: return "" normalized = best_value.quantize(Decimal("0.01")) text_value = format(normalized, "f").rstrip("0").rstrip(".") return f"{text_value}元" def _extract_date(text: str) -> str: match = DATE_PATTERN.search(text) if not match: return "" raw_value = str(match.group(1) or "").strip() normalized = raw_value.replace("年", "-").replace("月", "-").replace("日", "") normalized = normalized.replace("/", "-").replace(".", "-") parts = [part for part in normalized.split("-") if part] if len(parts) != 3: return raw_value year, month, day = parts return f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}" def _extract_merchant(text: str) -> str: for pattern in MERCHANT_PATTERNS: match = pattern.search(text) if not match: continue value = _clean_field_value(match.group(1)) if value: return value return "" def _extract_route(text: str) -> str: match = ROUTE_PATTERN.search(text) if not match: return "" start = _clean_field_value(match.group(1)) end = _clean_field_value(match.group(2)) if not start or not end or start == end: return "" return f"{start}-{end}" def _extract_pattern(pattern: re.Pattern[str], text: str) -> str: match = pattern.search(text) if not match: return "" return _clean_field_value(match.group(1)) def _clean_field_value(value: str) -> str: return str(value or "").strip().strip("::,,。.;;")