"""Classify expense documents and extract their key fields.

The module combines two signals:

* a keyword rule table (``DOCUMENT_RULES``) scored against the OCR text, and
* an optional LLM review performed through ``RuntimeChatService``.

Regex-based helpers additionally pull out amount, date, merchant, invoice
identifiers, trip number and route from the raw text.
"""
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Any

from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session

from app.services.runtime_chat import RuntimeChatService


@dataclass(frozen=True, slots=True)
class DocumentField:
    """One extracted field: machine ``key``, display ``label`` and ``value``."""

    key: str
    label: str
    value: str


@dataclass(frozen=True, slots=True)
class DocumentInsight:
    """Final classification of a document plus its extracted fields.

    ``classification_source`` is ``"rule"`` for keyword-only results and
    ``"llm_vision"`` / ``"llm_text"`` when a model answer was accepted.
    """

    document_type: str
    document_type_label: str
    scene_code: str
    scene_label: str
    expense_type: str
    fields: tuple[DocumentField, ...] = ()
    classification_source: str = "rule"
    classification_confidence: float = 0.0
    evidence: tuple[str, ...] = ()
    warnings: tuple[str, ...] = ()


@dataclass(frozen=True, slots=True)
class DocumentRule:
    """A keyword rule: matching ``keywords`` vote for ``document_type``.

    ``score_bias`` is added once to the score of a rule that matched at
    least one keyword, to favour (positive) or disfavour (negative) it.
    """

    document_type: str
    document_type_label: str
    scene_code: str
    scene_label: str
    expense_type: str
    keywords: tuple[str, ...]
    score_bias: float = 0.0


@dataclass(frozen=True, slots=True)
class RuleMatch:
    """Result of keyword matching; ``rule`` is ``None`` when nothing matched."""

    rule: DocumentRule | None
    confidence: float
    evidence: tuple[str, ...]
    score: float


class LlmDocumentClassification(BaseModel):
    """Schema of the JSON object the model is asked to return."""

    document_type: str = Field(default="other")
    scene_code: str = Field(default="other")
    scene_label: str = Field(default="其他票据")
    expense_type: str = Field(default="other")
    confidence: float = Field(default=0.0, ge=0.0, le=1.0)
    evidence: list[str] = Field(default_factory=list)


# Fallback used whenever no rule (or no known document type) applies.
DEFAULT_RULE = DocumentRule(
    document_type="other",
    document_type_label="其他单据",
    scene_code="other",
    scene_label="其他票据",
    expense_type="other",
    keywords=(),
    score_bias=0.0,
)

DOCUMENT_RULES: tuple[DocumentRule, ...] = (
    DocumentRule(
        document_type="flight_itinerary",
        document_type_label="机票/航班行程单",
        scene_code="travel",
        scene_label="差旅票据",
        expense_type="travel",
        keywords=("电子行程单", "航班号", "航班", "机票", "登机", "航空", "客票"),
        score_bias=0.34,
    ),
    DocumentRule(
        document_type="train_ticket",
        document_type_label="火车/高铁票",
        scene_code="travel",
        scene_label="差旅票据",
        expense_type="travel",
        keywords=("高铁", "火车", "动车", "铁路", "车次", "检票", "二等座", "一等座"),
        score_bias=0.32,
    ),
    DocumentRule(
        document_type="hotel_invoice",
        document_type_label="酒店住宿票据",
        scene_code="hotel",
        scene_label="住宿票据",
        expense_type="hotel",
        keywords=("住宿", "房费", "客房", "入住", "离店", "酒店", "宾馆", "间夜"),
        score_bias=0.16,
    ),
    DocumentRule(
        document_type="taxi_receipt",
        document_type_label="出租车/网约车票据",
        scene_code="transport",
        scene_label="交通票据",
        expense_type="transport",
        keywords=(
            "滴滴出行",
            "滴滴",
            "网约车",
            "出租车",
            "打车",
            "快车",
            "专车",
            "订单号",
            "上车",
            "下车",
            "起点",
            "终点",
            "里程",
            "司机",
        ),
        score_bias=0.38,
    ),
    DocumentRule(
        document_type="parking_toll_receipt",
        document_type_label="停车/通行费票据",
        scene_code="transport",
        scene_label="交通票据",
        expense_type="transport",
        keywords=("停车费", "通行费", "过路费", "收费站", "停车场", "停车"),
        score_bias=0.28,
    ),
    DocumentRule(
        document_type="meal_receipt",
        document_type_label="餐饮票据",
        scene_code="meal",
        scene_label="餐饮票据",
        expense_type="meal",
        keywords=("餐饮", "餐费", "用餐", "饭店", "酒楼", "餐厅", "食品", "外卖", "咖啡"),
        score_bias=0.14,
    ),
    DocumentRule(
        document_type="office_invoice",
        document_type_label="办公用品票据",
        scene_code="office",
        scene_label="办公用品票据",
        expense_type="office",
        keywords=("办公用品", "文具", "耗材", "打印纸", "墨盒", "硒鼓", "键盘", "鼠标"),
        score_bias=0.14,
    ),
    DocumentRule(
        document_type="meeting_invoice",
        document_type_label="会议/会务票据",
        scene_code="meeting",
        scene_label="会务票据",
        expense_type="meeting",
        keywords=("会议", "会务", "会展", "论坛", "会议室", "会场"),
        score_bias=0.12,
    ),
    DocumentRule(
        document_type="training_invoice",
        document_type_label="培训票据",
        scene_code="training",
        scene_label="培训票据",
        expense_type="training",
        keywords=("培训", "课程", "讲师", "教材", "学费", "认证"),
        score_bias=0.12,
    ),
    DocumentRule(
        document_type="vat_invoice",
        document_type_label="增值税发票",
        scene_code="other",
        scene_label="通用发票",
        expense_type="other",
        keywords=("发票代码", "发票号码", "价税合计", "增值税", "电子发票"),
        score_bias=-0.08,
    ),
    DocumentRule(
        document_type="receipt",
        document_type_label="一般收据/凭证",
        scene_code="other",
        scene_label="其他票据",
        expense_type="other",
        keywords=("收据", "凭证", "票据"),
        score_bias=-0.18,
    ),
)

DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES}
SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",)

# NOTE: the separator classes below accept both the ASCII and the full-width
# form of each punctuation mark (\uff1a = full-width colon, \uffe5 = full-width
# yen sign, \uff08/\uff09 = full-width parentheses). The full-width forms are
# written as escapes so that Unicode normalisation of this file cannot fold
# them into their ASCII twins.
AMOUNT_PATTERNS = (
    re.compile(
        r"(?:价税合计|合计|金额|总额|票价|支付金额|实付金额|实收金额)"
        r"[:\uff1a\s¥\uffe5]*([0-9]+(?:[.,][0-9]{1,2})?)"
    ),
    re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
DATE_PATTERN = re.compile(
    r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)"
)
INVOICE_NUMBER_PATTERN = re.compile(
    r"(?:发票号码|票号|单号|订单号)[:\uff1a\s]*([A-Za-z0-9-]{6,24})"
)
INVOICE_CODE_PATTERN = re.compile(r"(?:发票代码)[:\uff1a\s]*([A-Za-z0-9-]{6,24})")
TRIP_NO_PATTERN = re.compile(r"(?:车次|航班(?:号)?)[:\uff1a\s]*([A-Za-z0-9]{2,12})")
ROUTE_PATTERN = re.compile(
    r"([\u4e00-\u9fa5]{2,12})\s*(?:至|→|->|-)\s*([\u4e00-\u9fa5]{2,12})"
)
MERCHANT_PATTERNS = (
    # Labelled merchant, e.g. "销售方名称: <name>".
    re.compile(
        r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)"
        r"[:\uff1a\s]*([A-Za-z0-9\u4e00-\u9fa5()\uff08\uff09·&\\-]{2,40})"
    ),
    # Unlabelled name that ends with a typical merchant suffix.
    re.compile(
        r"([A-Za-z0-9\u4e00-\u9fa5()\uff08\uff09·&\\-]{2,40}"
        r"(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"
    ),
)

# Reasoning block some chat models prepend to their answer. The angle
# brackets are written as hex escapes (\x3c is "<", \x3e is ">") so that
# tag-stripping tooling cannot silently delete the delimiters and leave a
# bare ".*?" that would erase the entire response.
_THINK_BLOCK_PATTERN = re.compile(
    r"\x3cthink\x3e.*?\x3c/think\x3e",
    flags=re.DOTALL | re.IGNORECASE,
)

# Punctuation trimmed from both ends of an extracted value: ASCII and
# full-width colon, comma and semicolon, plus the ideographic and ASCII
# full stop.
_FIELD_TRIM_CHARS = ":\uff1a,\uff0c\u3002.;\uff1b"


class DocumentIntelligenceService:
    """Classify a document with keyword rules, optionally reviewed by an LLM."""

    def __init__(self, db: Session | None = None) -> None:
        """Create the service.

        Without a ``db`` session no ``RuntimeChatService`` is created and
        classification is rule-only.
        """
        self.runtime_chat_service = RuntimeChatService(db) if db is not None else None

    def build_document_insight(
        self,
        *,
        filename: str = "",
        summary: str = "",
        text: str = "",
        preview_data_url: str = "",
    ) -> DocumentInsight:
        """Classify a document and extract its fields.

        Args:
            filename: Original file name; takes part in keyword matching.
            summary: Short summary of the document.
            text: OCR text of the document.
            preview_data_url: Optional image URL forwarded to the vision model.

        Returns:
            The rule-based insight, or the merged insight when a model answer
            is available and accepted by ``_merge_rule_and_model``.
        """
        # Normalise every input exactly once so that all later steps
        # (including the ``has_preview`` flag) see the same values.
        clean_filename = str(filename or "").strip()
        clean_summary = str(summary or "").strip()
        clean_text = str(text or "").strip()
        clean_preview = str(preview_data_url or "").strip()

        raw_text = " ".join([clean_filename, clean_summary, clean_text]).strip()
        # Keywords are matched on whitespace-free, lower-cased text.
        compact = re.sub(r"\s+", "", raw_text).lower()

        rule_match = _match_document_rule(compact)
        base_rule = rule_match.rule or DEFAULT_RULE
        fields = tuple(_extract_document_fields(raw_text))

        rule_insight = DocumentInsight(
            document_type=base_rule.document_type,
            document_type_label=base_rule.document_type_label,
            scene_code=base_rule.scene_code,
            scene_label=base_rule.scene_label,
            expense_type=base_rule.expense_type,
            fields=fields,
            classification_source="rule",
            classification_confidence=rule_match.confidence,
            evidence=rule_match.evidence,
        )

        llm_result = self._classify_with_model(
            filename=clean_filename,
            summary=clean_summary,
            text=clean_text,
            preview_data_url=clean_preview,
            rule_insight=rule_insight,
            fields=fields,
        )
        if llm_result is None:
            return rule_insight

        return self._merge_rule_and_model(
            rule_insight=rule_insight,
            llm_result=llm_result,
            fields=fields,
            # Use the stripped value: a whitespace-only URL was never sent to
            # the model and must not lower the override threshold.
            has_preview=bool(clean_preview),
        )

    def _classify_with_model(
        self,
        *,
        filename: str,
        summary: str,
        text: str,
        preview_data_url: str,
        rule_insight: DocumentInsight,
        fields: tuple[DocumentField, ...],
    ) -> tuple[str, LlmDocumentClassification] | None:
        """Ask the chat model to review the rule-based classification.

        The vision slot is tried first when a preview URL is given; the text
        slots are used as fallback (or directly when there is no preview).

        Returns:
            ``(source, classification)`` where ``source`` is ``"llm_vision"``
            or ``"llm_text"``, or ``None`` when no chat service is configured,
            there is neither text nor summary, or no answer could be parsed.
        """
        if self.runtime_chat_service is None:
            return None

        trimmed_text = text.strip()
        if not trimmed_text and not summary.strip():
            return None

        facts = {
            "filename": filename,
            "summary": summary[:300],
            "ocr_text_excerpt": trimmed_text[:2000],
            "rule_candidate": {
                "document_type": rule_insight.document_type,
                "document_type_label": rule_insight.document_type_label,
                "scene_code": rule_insight.scene_code,
                "scene_label": rule_insight.scene_label,
                "expense_type": rule_insight.expense_type,
                "confidence": round(rule_insight.classification_confidence, 2),
                "evidence": list(rule_insight.evidence),
            },
            "extracted_fields": [
                {"key": field.key, "label": field.label, "value": field.value}
                for field in fields
            ],
            "allowed_document_types": list(SUPPORTED_DOCUMENT_TYPES),
        }
        system_prompt = (
            "你是企业报销票据识别复核器。"
            "你的任务不是 OCR,而是在已有 OCR 文本和票据预览基础上判断票据类型。"
            "只输出 JSON 对象,不要输出 Markdown、解释或代码块。"
            "document_type 只能是:"
            f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}。"
            "如果证据不足,返回 other。"
            "严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。"
            "如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。"
            "例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。"
            "输出字段:document_type, scene_code, scene_label, expense_type, confidence, evidence。"
        )
        user_prompt = (
            "请根据以下票据事实给出最终分类 JSON:\n"
            f"{json.dumps(facts, ensure_ascii=False, indent=2)}\n\n"
            "示例输出:\n"
            "{\n"
            ' "document_type": "taxi_receipt",\n'
            ' "scene_code": "transport",\n'
            ' "scene_label": "交通票据",\n'
            ' "expense_type": "transport",\n'
            ' "confidence": 0.86,\n'
            ' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"]\n'
            "}"
        )

        if preview_data_url:
            response_text = self.runtime_chat_service.complete(
                [
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": user_prompt},
                            {"type": "image_url", "image_url": {"url": preview_data_url}},
                        ],
                    },
                ],
                slot_priority=("vlm",),
                max_tokens=320,
                temperature=0.0,
            )
            parsed = self._parse_llm_payload(response_text)
            if parsed is not None:
                return "llm_vision", parsed

        # Text-only request: used without a preview, or when the vision
        # answer could not be parsed.
        response_text = self.runtime_chat_service.complete(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            slot_priority=("main", "backup"),
            max_tokens=320,
            temperature=0.0,
        )
        parsed = self._parse_llm_payload(response_text)
        if parsed is not None:
            return "llm_text", parsed
        return None

    @staticmethod
    def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None:
        """Parse and normalise a raw model answer.

        Returns ``None`` when no JSON object is found or it fails schema
        validation. Unknown document types become ``"other"``, blank scene and
        expense values fall back to the rule of the resolved type, and at most
        four non-empty evidence strings are kept.
        """
        payload_json = _extract_json_payload(response_text)
        if payload_json is None:
            return None
        try:
            parsed = LlmDocumentClassification.model_validate(payload_json)
        except ValidationError:
            return None

        normalized_type = str(parsed.document_type or "other").strip().lower() or "other"
        if normalized_type not in SUPPORTED_DOCUMENT_TYPES:
            normalized_type = "other"
        base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE)

        evidence = [
            str(item or "").strip()
            for item in parsed.evidence
            if str(item or "").strip()
        ][:4]

        return LlmDocumentClassification(
            document_type=normalized_type,
            scene_code=str(parsed.scene_code or base_rule.scene_code).strip()
            or base_rule.scene_code,
            scene_label=str(parsed.scene_label or base_rule.scene_label).strip()
            or base_rule.scene_label,
            expense_type=str(parsed.expense_type or base_rule.expense_type).strip()
            or base_rule.expense_type,
            confidence=float(parsed.confidence),
            evidence=evidence,
        )

    @staticmethod
    def _merge_rule_and_model(
        *,
        rule_insight: DocumentInsight,
        llm_result: tuple[str, LlmDocumentClassification],
        fields: tuple[DocumentField, ...],
        has_preview: bool,
    ) -> DocumentInsight:
        """Decide whether the model answer replaces the rule-based insight.

        The model wins when its confidence is at least 0.55 and one of:

        * it agrees with the rule on the document type;
        * the rule found nothing (``"other"``) and the model did;
        * it disagrees with a concrete type and its confidence reaches the
          threshold: 0.60 with a preview, otherwise the larger of 0.76 and
          rule confidence + 0.12.

        A model answer of ``"other"`` never overrides a concrete rule result.
        """
        source, parsed = llm_result
        if parsed.confidence < 0.55:
            return rule_insight

        should_override = False
        if parsed.document_type == rule_insight.document_type:
            should_override = True
        elif rule_insight.document_type == "other" and parsed.document_type != "other":
            should_override = True
        elif parsed.document_type != "other":
            threshold = (
                0.60
                if has_preview
                else max(0.76, rule_insight.classification_confidence + 0.12)
            )
            should_override = parsed.confidence >= threshold

        if not should_override:
            return rule_insight

        rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
        warnings = list(rule_insight.warnings)
        if parsed.document_type != rule_insight.document_type:
            warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")

        # The schema defaults ("other" / "其他票据") mean the model gave no
        # specific value; the rule's values are used in that case.
        return DocumentInsight(
            document_type=rule.document_type,
            document_type_label=rule.document_type_label,
            scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
            scene_label=rule.scene_label
            if parsed.scene_label == "其他票据"
            else parsed.scene_label,
            expense_type=rule.expense_type
            if parsed.expense_type == "other"
            else parsed.expense_type,
            fields=fields,
            classification_source=source,
            classification_confidence=max(
                parsed.confidence, rule_insight.classification_confidence
            ),
            evidence=tuple(parsed.evidence or rule_insight.evidence),
            warnings=tuple(warnings),
        )


def build_document_insight(
    *,
    filename: str = "",
    summary: str = "",
    text: str = "",
    preview_data_url: str = "",
) -> DocumentInsight:
    """Classify a document with keyword rules only.

    Convenience wrapper that creates a ``DocumentIntelligenceService``
    without a database session, so no model is consulted.
    """
    return DocumentIntelligenceService().build_document_insight(
        filename=filename,
        summary=summary,
        text=text,
        preview_data_url=preview_data_url,
    )


def _match_document_rule(compact_text: str) -> RuleMatch:
    """Return the best-scoring rule for whitespace-free, lower-cased text.

    Each matched keyword adds 0.92 plus 0.08 per character (capped at six
    characters); the rule's ``score_bias`` is added once. The first rule with
    the highest score wins. Confidence is ``0.30 + min(score, 4.8) * 0.12``,
    capped at 0.94 and rounded to two decimals.
    """
    best_rule = DEFAULT_RULE
    best_evidence: tuple[str, ...] = ()
    best_score = 0.0

    for rule in DOCUMENT_RULES:
        matched = tuple(
            keyword for keyword in rule.keywords if keyword.lower() in compact_text
        )
        if not matched:
            continue
        score = (
            float(rule.score_bias)
            + len(matched) * 0.92
            + sum(min(len(keyword), 6) * 0.08 for keyword in matched)
        )
        if score > best_score:
            best_rule = rule
            best_evidence = matched
            best_score = score

    if best_score <= 0:
        return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0)

    confidence = min(0.94, 0.30 + min(best_score, 4.8) * 0.12)
    return RuleMatch(
        rule=best_rule,
        confidence=round(confidence, 2),
        evidence=best_evidence[:4],
        score=best_score,
    )


def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
    """Extract the first JSON object found in a model response.

    Reasoning blocks wrapped in think tags are removed first. Candidates are
    then tried in order: the content of a Markdown code fence, the whole
    cleaned text, and the span from the first ``{`` to the last ``}``.

    Returns:
        The decoded object, or ``None`` when the response is empty or no
        candidate decodes to a JSON object.
    """
    if not response_text:
        return None

    # Only delimited reasoning blocks are removed. An undelimited lazy ".*?"
    # under DOTALL would consume every character and make this function
    # return None for every response.
    cleaned = _THINK_BLOCK_PATTERN.sub("", response_text).strip()
    if not cleaned:
        return None

    fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL)
    candidates = [fenced_match.group(1)] if fenced_match else []
    candidates.append(cleaned)

    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start != -1 and end != -1 and end > start:
        candidates.append(cleaned[start : end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _extract_document_fields(text: str) -> list[DocumentField]:
    """Run every field extractor on ``text`` and keep the non-empty results.

    Fields are returned in a fixed order: amount, date, merchant, invoice
    number, invoice code, trip number, route.
    """
    extracted = (
        ("amount", "金额", _extract_amount(text)),
        ("date", "日期", _extract_date(text)),
        ("merchant_name", "商户", _extract_merchant(text)),
        ("invoice_number", "票据号码", _extract_pattern(INVOICE_NUMBER_PATTERN, text)),
        ("invoice_code", "发票代码", _extract_pattern(INVOICE_CODE_PATTERN, text)),
        ("trip_no", "车次/航班", _extract_pattern(TRIP_NO_PATTERN, text)),
        ("route", "行程", _extract_route(text)),
    )
    return [
        DocumentField(key=key, label=label, value=value)
        for key, label, value in extracted
        if value
    ]


def _extract_amount(text: str) -> str:
    """Return the largest positive amount found, formatted like ``"128.5元"``.

    Patterns are tried in order; the first pattern that yields at least one
    positive amount decides the result. A comma is read as a decimal
    separator. Returns ``""`` when no amount is found.
    """
    best_value: Decimal | None = None
    for pattern in AMOUNT_PATTERNS:
        for match in pattern.finditer(text):
            raw_value = str(match.group(1) or "").replace(",", ".").strip()
            try:
                candidate = Decimal(raw_value)
            except InvalidOperation:
                continue
            if candidate <= Decimal("0.00"):
                continue
            if best_value is None or candidate > best_value:
                best_value = candidate
        if best_value is not None:
            # A labelled amount takes precedence over bare "<n>元" matches.
            break

    if best_value is None:
        return ""

    normalized = best_value.quantize(Decimal("0.01"))
    # quantize() always yields two decimals, so only fractional zeros (and a
    # then-dangling decimal point) are trimmed here.
    text_value = format(normalized, "f").rstrip("0").rstrip(".")
    return f"{text_value}元"


def _extract_date(text: str) -> str:
    """Return the first date in ``text`` as ``YYYY-MM-DD``, or ``""``.

    If the match cannot be split into exactly three parts, the raw matched
    text is returned unchanged.
    """
    match = DATE_PATTERN.search(text)
    if not match:
        return ""

    raw_value = str(match.group(1) or "").strip()
    normalized = raw_value.replace("年", "-").replace("月", "-").replace("日", "")
    normalized = normalized.replace("/", "-").replace(".", "-")
    parts = [part for part in normalized.split("-") if part]
    if len(parts) != 3:
        return raw_value

    year, month, day = parts
    return f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}"


def _extract_merchant(text: str) -> str:
    """Return the first non-empty merchant name matched, or ``""``."""
    for pattern in MERCHANT_PATTERNS:
        match = pattern.search(text)
        if not match:
            continue
        value = _clean_field_value(match.group(1))
        if value:
            return value
    return ""


def _extract_route(text: str) -> str:
    """Return ``"<start>-<end>"`` for the first route found, or ``""``.

    A route whose two ends are empty or identical is rejected.
    """
    match = ROUTE_PATTERN.search(text)
    if not match:
        return ""

    start = _clean_field_value(match.group(1))
    end = _clean_field_value(match.group(2))
    if not start or not end or start == end:
        return ""
    return f"{start}-{end}"


def _extract_pattern(pattern: re.Pattern[str], text: str) -> str:
    """Return the cleaned first capture group of ``pattern``, or ``""``."""
    match = pattern.search(text)
    if not match:
        return ""
    return _clean_field_value(match.group(1))


def _clean_field_value(value: str) -> str:
    """Strip whitespace, then leading/trailing punctuation, from ``value``."""
    return str(value or "").strip().strip(_FIELD_TRIM_CHARS)