"""Classify expense documents and extract their key fields.

Rule-based extraction (regular expressions plus keyword scoring against
``DOCUMENT_RULES``) always runs.  ``DocumentIntelligenceService`` exposes a
``_classify_with_model`` hook; the default implementation returns ``None``,
in which case the rule result is returned unchanged.  When the hook returns a
``(source, LlmDocumentClassification)`` pair, it is merged with the rule
result by ``_merge_rule_and_model``.
"""

from __future__ import annotations

import json
import re
from decimal import Decimal, InvalidOperation
from typing import Any

from pydantic import ValidationError
from sqlalchemy.orm import Session

from app.services.document_intelligence_rules import (
    DEFAULT_RULE,
    DOCUMENT_RULES,
    DOCUMENT_TYPE_RULE_MAP,
    SUPPORTED_DOCUMENT_TYPES,
)
from app.services.document_intelligence_types import (
    DocumentField,
    DocumentInsight,
    LlmDocumentClassification,
    RuleMatch,
)

# Character classes below accept both ASCII and full-width punctuation
# (":"/"："、"¥"/"￥"、"()"/"（）"), since Chinese documents and OCR output
# use either form.
AMOUNT_PATTERNS = (
    re.compile(
        r"(?:价税合计|合计金额|费用合计|总费用|费用总计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额|房费|住宿费)"
        r"[:：\s¥￥人民币为是]*([0-9]+(?:[.,][0-9]{1,2})?)"
    ),
    re.compile(r"[¥￥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"),
    re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")

# NOTE(review): the next five patterns are heuristics; validate them against
# real document samples before relying on their exact coverage.
# Clock time such as "08:30" or "8：30"; group 1 = hour, group 2 = minute.
TIME_PATTERN = re.compile(r"(?<!\d)([01]?\d|2[0-3])[:：]([0-5]\d)(?!\d)")
# Invoice / ticket / order number that follows its label.
INVOICE_NUMBER_PATTERN = re.compile(
    r"(?:发票号码|票据号码|电子客票号|客票号|票号|订单号|单号)[:：\s]*([A-Za-z0-9-]{6,32})"
)
# Invoice code that follows its label.
INVOICE_CODE_PATTERN = re.compile(r"发票代码[:：\s]*([0-9]{10,12})")
# Train number (letter prefix + 2-4 digits) or flight number (two-character
# carrier code containing a letter + 3-4 digits), not embedded in a longer token.
TRIP_NO_PATTERN = re.compile(
    r"(?<![A-Za-z0-9])([GDCZTKYSL]\d{2,4}|(?:[A-Z]{2}|[A-Z]\d|\d[A-Z])\d{3,4})(?![A-Za-z0-9])"
)
# "<origin><separator><destination>" with Chinese place names; groups 1 and 2.
ROUTE_PATTERN = re.compile(
    r"([\u4e00-\u9fa5]{2,12})\s*(?:至|到|→|—|–|->|-)\s*([\u4e00-\u9fa5]{2,12})"
)

MERCHANT_PATTERNS = (
    re.compile(r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)[:：\s]*([A-Za-z0-9\u4e00-\u9fa5()（）·&\\-]{2,40})"),
    re.compile(r"([A-Za-z0-9\u4e00-\u9fa5()（）·&\\-]{2,40}(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"),
)

# Field keys that denote an invoice issue date (never relabelled as a trip date).
_INVOICE_DATE_FIELD_KEYS = frozenset({"issued_at", "issue_date", "invoice_date"})
# Field keys that denote when the trip itself happened.
_TRIP_DATE_FIELD_KEYS = frozenset({
    "travel_date", "trip_date", "journey_date", "departure_date", "departure_time",
    "depart_date", "depart_time", "boarding_date", "boarding_time", "train_date",
    "train_time", "train_departure_time", "scheduled_departure_time", "flight_date",
    "flight_time", "ride_date", "ride_time", "pickup_time", "start_time",
})
DATE_FIELD_KEYS = {"date", "time", *_INVOICE_DATE_FIELD_KEYS, *_TRIP_DATE_FIELD_KEYS}

# Label text that marks a trip (not invoice) date; shared by key normalization
# and trip-date scoring so the two stay consistent.
_TRIP_DATE_LABEL_TOKENS = (
    "行程日期", "出发日期", "出发时间", "列车出发时间", "发车日期", "发车时间", "开车时间",
    "乘车日期", "乘车时间", "起飞日期", "起飞时间", "航班日期", "上车时间", "用车时间",
)

TRIP_DATE_LABEL_BY_DOCUMENT_TYPE = {
    "train_ticket": "列车出发时间",
    "flight_itinerary": "起飞日期",
    "taxi_receipt": "乘车时间",
    "transport_receipt": "乘车时间",
    "parking_toll_receipt": "通行日期",
}
TRIP_DATE_FIELD_LABEL_TOKENS = (
    "日期", "时间", "开票日期", "发生时间", "行程日期", "出发日期", "出发时间", "列车出发时间",
    "发车日期", "发车时间", "开车时间", "乘车日期", "乘车时间", "起飞日期", "航班日期",
    "上车时间", "用车时间",
)


class DocumentIntelligenceService:
    """Classify a document and extract key fields, optionally model-assisted."""

    def __init__(self, db: Session | None = None) -> None:
        # Stored for subclasses / model-backed hooks; not read in this module.
        self.db = db

    def build_document_insight(
        self,
        *,
        filename: str = "",
        summary: str = "",
        text: str = "",
        preview_data_url: str = "",
    ) -> DocumentInsight:
        """Build a ``DocumentInsight`` from a document's filename, summary and text.

        Args:
            filename: Original file name; it also contributes keywords.
            summary: Free-text summary of the document.
            text: OCR / extracted text of the document.
            preview_data_url: Optional image preview passed to the model hook;
                whitespace-only values count as "no preview".

        Returns:
            The rule-based insight, or the merged rule + model insight when
            ``_classify_with_model`` returns a result.
        """
        filename = str(filename or "").strip()
        summary = str(summary or "").strip()
        text = str(text or "").strip()
        preview_data_url = str(preview_data_url or "").strip()

        raw_text = " ".join([filename, summary, text]).strip()
        compact = re.sub(r"\s+", "", raw_text).lower()
        rule_match = _match_document_rule(compact)
        base_rule = rule_match.rule or DEFAULT_RULE
        fields = _apply_document_type_field_labels(
            tuple(_extract_document_fields(raw_text, base_rule.document_type)),
            base_rule.document_type,
        )
        rule_insight = DocumentInsight(
            document_type=base_rule.document_type,
            document_type_label=base_rule.document_type_label,
            scene_code=base_rule.scene_code,
            scene_label=base_rule.scene_label,
            expense_type=base_rule.expense_type,
            fields=fields,
            classification_source="rule",
            classification_confidence=rule_match.confidence,
            evidence=rule_match.evidence,
        )
        llm_result = self._classify_with_model(
            filename=filename,
            summary=summary,
            text=text,
            preview_data_url=preview_data_url,
            rule_insight=rule_insight,
            fields=fields,
        )
        if llm_result is None:
            return rule_insight
        return self._merge_rule_and_model(
            rule_insight=rule_insight,
            llm_result=llm_result,
            fields=fields,
            # Use the stripped value, matching what the model hook received.
            has_preview=bool(preview_data_url),
        )

    def _classify_with_model(
        self,
        *,
        filename: str,
        summary: str,
        text: str,
        preview_data_url: str,
        rule_insight: DocumentInsight,
        fields: tuple[DocumentField, ...],
    ) -> tuple[str, LlmDocumentClassification] | None:
        """Model classification hook; returns ``(source, classification)`` or ``None``.

        The default implementation performs no model call and returns ``None``.
        """
        return None

    @staticmethod
    def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None:
        """Parse and normalize a model response into an ``LlmDocumentClassification``.

        Unknown document types become ``"other"``; empty scene/expense values
        fall back to the matching rule's defaults; at most four evidence items
        are kept.  Returns ``None`` when no JSON object is found or validation fails.
        """
        payload_json = _extract_json_payload(response_text)
        if payload_json is None:
            return None
        try:
            parsed = LlmDocumentClassification.model_validate(payload_json)
        except ValidationError:
            return None

        normalized_type = str(parsed.document_type or "other").strip().lower() or "other"
        if normalized_type not in SUPPORTED_DOCUMENT_TYPES:
            normalized_type = "other"
        base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE)
        evidence = [
            str(item or "").strip()
            for item in (parsed.evidence or [])
            if str(item or "").strip()
        ][:4]
        normalized_fields = _apply_document_type_field_labels(
            tuple(_normalize_llm_document_fields(parsed.fields)),
            normalized_type,
        )
        return LlmDocumentClassification(
            document_type=normalized_type,
            scene_code=str(parsed.scene_code or base_rule.scene_code).strip() or base_rule.scene_code,
            scene_label=str(parsed.scene_label or base_rule.scene_label).strip() or base_rule.scene_label,
            expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type,
            confidence=float(parsed.confidence),
            evidence=evidence,
            fields=normalized_fields,
        )

    @staticmethod
    def _merge_rule_and_model(
        *,
        rule_insight: DocumentInsight,
        llm_result: tuple[str, LlmDocumentClassification],
        fields: tuple[DocumentField, ...],
        has_preview: bool,
    ) -> DocumentInsight:
        """Combine the rule insight with a model classification.

        Model fields are merged when the model saw a preview or its confidence
        is at least 0.55.  Below 0.55 the rule classification is always kept.
        Otherwise the model's document type is adopted when it agrees with the
        rules, when the rules found nothing (``"other"``), or when its confidence
        clears a threshold (0.60 with a preview, else
        ``max(0.76, rule confidence + 0.12)``).  A model answer of ``"other"``
        never overrides a concrete rule type.
        """
        source, parsed = llm_result
        warnings = list(rule_insight.warnings)
        merged_fields = rule_insight.fields
        if parsed.fields and (has_preview or parsed.confidence >= 0.55):
            merged_fields = _merge_document_fields(rule_insight.fields, tuple(parsed.fields))
            if merged_fields != rule_insight.fields:
                warnings.append("票据关键信息已结合大模型复核结果修正,建议人工再核对原图。")

        if parsed.confidence < 0.55:
            return _keep_rule_classification(rule_insight, merged_fields, warnings)

        if parsed.document_type == rule_insight.document_type:
            should_override = True
        elif parsed.document_type == "other":
            should_override = False
        elif rule_insight.document_type == "other":
            should_override = True
        else:
            threshold = 0.60 if has_preview else max(0.76, rule_insight.classification_confidence + 0.12)
            should_override = parsed.confidence >= threshold

        if not should_override:
            return _keep_rule_classification(rule_insight, merged_fields, warnings)

        rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
        if parsed.document_type != rule_insight.document_type:
            warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
        return DocumentInsight(
            document_type=rule.document_type,
            document_type_label=rule.document_type_label,
            scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
            scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
            expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
            fields=_apply_document_type_field_labels(merged_fields, rule.document_type),
            classification_source=source,
            classification_confidence=max(parsed.confidence, rule_insight.classification_confidence),
            evidence=tuple(parsed.evidence or rule_insight.evidence),
            warnings=tuple(warnings),
        )


def _keep_rule_classification(
    rule_insight: DocumentInsight,
    merged_fields: tuple[DocumentField, ...],
    warnings: list[str],
) -> DocumentInsight:
    """Keep the rule classification, attaching merged fields/warnings if fields changed."""
    if merged_fields == rule_insight.fields:
        return rule_insight
    return DocumentInsight(
        document_type=rule_insight.document_type,
        document_type_label=rule_insight.document_type_label,
        scene_code=rule_insight.scene_code,
        scene_label=rule_insight.scene_label,
        expense_type=rule_insight.expense_type,
        fields=_apply_document_type_field_labels(merged_fields, rule_insight.document_type),
        classification_source=rule_insight.classification_source,
        classification_confidence=rule_insight.classification_confidence,
        evidence=rule_insight.evidence,
        warnings=tuple(warnings),
    )


def build_document_insight(
    *,
    filename: str = "",
    summary: str = "",
    text: str = "",
    preview_data_url: str = "",
) -> DocumentInsight:
    """Convenience wrapper around ``DocumentIntelligenceService().build_document_insight``."""
    return DocumentIntelligenceService().build_document_insight(
        filename=filename,
        summary=summary,
        text=text,
        preview_data_url=preview_data_url,
    )


def _match_document_rule(compact_text: str) -> RuleMatch:
    """Pick the highest-scoring rule whose keywords occur in ``compact_text``.

    Score = ``score_bias`` + 0.92 per matched keyword + 0.08 per character of
    each keyword (capped at 6 characters).  Confidence is
    ``min(0.94, 0.30 + min(score, 4.8) * 0.12)`` rounded to 2 places.  Returns
    a rule-less ``RuleMatch`` when no rule scores above zero.
    """
    best_rule = DEFAULT_RULE
    best_evidence: tuple[str, ...] = ()
    best_score = 0.0
    for rule in DOCUMENT_RULES:
        matched = tuple(keyword for keyword in rule.keywords if keyword.lower() in compact_text)
        if not matched:
            continue
        score = (
            float(rule.score_bias)
            + len(matched) * 0.92
            + sum(min(len(keyword), 6) * 0.08 for keyword in matched)
        )
        if score > best_score:
            best_rule, best_evidence, best_score = rule, matched, score

    if best_score <= 0:
        return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0)
    confidence = min(0.94, 0.30 + min(best_score, 4.8) * 0.12)
    return RuleMatch(
        rule=best_rule,
        confidence=round(confidence, 2),
        evidence=best_evidence[:4],
        score=best_score,
    )


def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
    """Extract the first JSON object from a model response.

    ``<think>...</think>`` reasoning blocks are removed first.  Candidates are
    tried in order: a fenced ```json block, the whole text, and the span from
    the first ``{`` to the last ``}``.  Returns ``None`` if none parses to a dict.
    """
    if not response_text:
        return None
    cleaned = re.sub(
        r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE
    ).strip()
    if not cleaned:
        return None

    fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL)
    candidates = [fenced_match.group(1)] if fenced_match else []
    candidates.append(cleaned)
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start != -1 and end > start:
        candidates.append(cleaned[start : end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _raw_field_attr(field: Any, name: str) -> str:
    """Read ``name`` from a ``DocumentField``, dict or attribute object as a stripped string."""
    if isinstance(field, dict):
        value = field.get(name)
    else:
        value = getattr(field, name, "")
    return str(value or "").strip()


def _normalize_llm_document_fields(raw_fields: list[DocumentField] | list[dict[str, Any]]) -> list[DocumentField]:
    """Map model-provided fields onto canonical keys, labels and values.

    Fields whose key cannot be mapped, whose value normalizes to empty, or
    whose canonical key was already seen are dropped (the first one wins).
    """
    normalized: list[DocumentField] = []
    seen_keys: set[str] = set()
    for field in raw_fields or []:
        key = _normalize_llm_document_field_key(
            _raw_field_attr(field, "key"), _raw_field_attr(field, "label")
        )
        if not key or key in seen_keys:
            continue
        value = _normalize_llm_document_field_value(key, _raw_field_attr(field, "value"))
        if not value:
            continue
        seen_keys.add(key)
        normalized.append(DocumentField(key=key, label=_llm_document_field_label(key), value=value))
    return normalized


def _normalize_llm_document_field_key(key: str, label: str) -> str:
    """Return the canonical field key for a model ``key``/``label`` pair, or ``""``.

    Checks run in priority order (amount, trip date, invoice date, generic
    date, merchant, invoice number, invoice code, trip number, route), so the
    more specific date labels win over the generic "日期"/"时间".
    """
    compact_key = str(key or "").strip().lower()
    compact_label = str(label or "").replace(" ", "").lower()

    def label_has(tokens: tuple[str, ...]) -> bool:
        return any(token in compact_label for token in tokens)

    if compact_key in {"amount", "total_amount", "payment_amount", "paid_amount"} or label_has(
        ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
    ):
        return "amount"
    if compact_key in _TRIP_DATE_FIELD_KEYS or label_has(_TRIP_DATE_LABEL_TOKENS):
        return "trip_date"
    if compact_key in _INVOICE_DATE_FIELD_KEYS or "开票日期" in compact_label:
        return "invoice_date"
    if compact_key in {"date", "time"} or label_has(("日期", "时间", "发生时间")):
        return "date"
    if compact_key in {"merchant_name", "merchant", "seller_name", "vendor_name"} or label_has(
        ("商户", "酒店", "销售方", "开票方", "收款方")
    ):
        return "merchant_name"
    if compact_key in {"invoice_number", "ticket_number", "order_no", "order_number"} or label_has(
        ("票据号码", "发票号码", "票号", "单号", "订单号")
    ):
        return "invoice_number"
    if compact_key == "invoice_code" or "发票代码" in compact_label:
        return "invoice_code"
    if compact_key in {"trip_no", "flight_no", "train_no"} or label_has(("车次", "航班")):
        return "trip_no"
    if compact_key in {"route", "trip_route"} or label_has(("行程", "路线")):
        return "route"
    return ""


def _format_amount(value: Decimal) -> str:
    """Render a positive amount as e.g. ``"128.5元"`` (2-decimal rounding, trailing zeros dropped)."""
    try:
        normalized = value.quantize(Decimal("0.01"))
    except InvalidOperation:
        # Too many digits for the decimal context precision; render as-is
        # rather than raising.
        normalized = value
    text_value = format(normalized, "f")
    if "." in text_value:
        text_value = text_value.rstrip("0").rstrip(".")
    return f"{text_value}元"


def _normalize_llm_document_field_value(key: str, value: str) -> str:
    """Normalize a model field value for canonical ``key``; ``""`` means "discard"."""
    raw_value = str(value or "").strip()
    if not raw_value:
        return ""
    if key == "amount":
        amount = _extract_amount(raw_value)
        if amount:
            return amount
        cleaned = (
            raw_value.replace("¥", "").replace("￥", "").replace("元", "").replace(",", ".").strip()
        )
        try:
            candidate = Decimal(cleaned)
        except InvalidOperation:
            return ""
        # Decimal() accepts "NaN"/"Infinity"; ordering or quantizing those raises.
        if not candidate.is_finite() or candidate <= Decimal("0.00"):
            return ""
        return _format_amount(candidate)
    if key in {"date", "time", "invoice_date", "trip_date"}:
        return _extract_date(raw_value) or _clean_field_value(raw_value)
    if key == "route":
        return _extract_route(raw_value) or _clean_field_value(
            raw_value.replace("→", "-").replace("至", "-").replace("->", "-")
        )
    return _clean_field_value(raw_value)


def _llm_document_field_label(key: str) -> str:
    """Return the display label for a canonical field key (the key itself if unknown)."""
    return {
        "amount": "金额",
        "date": "日期",
        "invoice_date": "开票日期",
        "trip_date": "行程日期",
        "merchant_name": "商户",
        "invoice_number": "票据号码",
        "invoice_code": "发票代码",
        "trip_no": "车次/航班",
        "route": "行程",
    }.get(key, key)


def _apply_document_type_field_labels(
    fields: tuple[DocumentField, ...],
    document_type: str,
) -> tuple[DocumentField, ...]:
    """Relabel date fields with the trip-date label for trip document types.

    Invoice issue dates keep their label.  Fields are returned unchanged for
    document types without an entry in ``TRIP_DATE_LABEL_BY_DOCUMENT_TYPE``.
    """
    date_label = TRIP_DATE_LABEL_BY_DOCUMENT_TYPE.get(str(document_type or "").strip().lower())
    if not date_label:
        return fields

    adjusted: list[DocumentField] = []
    for field in fields:
        compact_key = str(field.key or "").strip().lower()
        compact_label = str(field.label or "").replace(" ", "")
        is_invoice_date = compact_key in _INVOICE_DATE_FIELD_KEYS or any(
            token in compact_label for token in ("开票日期", "发票日期")
        )
        is_date_field = compact_key in DATE_FIELD_KEYS or any(
            token in compact_label for token in TRIP_DATE_FIELD_LABEL_TOKENS
        )
        if is_date_field and not is_invoice_date:
            adjusted.append(DocumentField(key=field.key, label=date_label, value=field.value))
        else:
            adjusted.append(field)
    return tuple(adjusted)


def _merge_document_fields(
    base_fields: tuple[DocumentField, ...],
    override_fields: tuple[DocumentField, ...],
) -> tuple[DocumentField, ...]:
    """Overlay ``override_fields`` onto ``base_fields`` by key.

    Empty keys/values are ignored; base order is kept and new keys are appended.
    """
    merged = {field.key: field for field in base_fields if field.key and field.value}
    order = [field.key for field in base_fields if field.key and field.value]
    for field in override_fields:
        if not field.key or not field.value:
            continue
        merged[field.key] = field
        if field.key not in order:
            order.append(field.key)
    return tuple(merged[key] for key in order if key in merged)


def _extract_document_fields(text: str, document_type: str = "") -> list[DocumentField]:
    """Extract amount, date, merchant, invoice number/code, trip number and route from text."""
    extracted = (
        ("amount", "金额", _extract_amount(text)),
        ("date", "日期", _extract_date(text, document_type=document_type)),
        ("merchant_name", "商户", _extract_merchant(text)),
        ("invoice_number", "票据号码", _extract_pattern(INVOICE_NUMBER_PATTERN, text)),
        ("invoice_code", "发票代码", _extract_pattern(INVOICE_CODE_PATTERN, text)),
        ("trip_no", "车次/航班", _extract_pattern(TRIP_NO_PATTERN, text)),
        ("route", "行程", _extract_route(text)),
    )
    return [DocumentField(key=key, label=label, value=value) for key, label, value in extracted if value]


def _extract_amount(text: str) -> str:
    """Return the largest positive amount found by ``AMOUNT_PATTERNS`` (e.g. ``"88元"``), or ``""``.

    Year-like integers that are part of a date (e.g. "2024-05") are skipped.
    """
    best_value: Decimal | None = None
    for pattern in AMOUNT_PATTERNS:
        for match in pattern.finditer(text):
            raw_value = str(match.group(1) or "").replace(",", ".").strip()
            try:
                candidate = Decimal(raw_value)
            except InvalidOperation:
                continue
            if candidate <= Decimal("0.00"):
                continue
            if _is_amount_match_date_fragment(candidate, text, match.start(1), match.end(1)):
                continue
            if best_value is None or candidate > best_value:
                best_value = candidate
    if best_value is None:
        return ""
    return _format_amount(best_value)


def _is_amount_match_date_fragment(amount: Decimal, text: str, start: int, end: int) -> bool:
    """Return True when an integer amount in 1900..2099 sits next to date separators."""
    if start < 0 or end < 0:
        return False
    # Range check first: comparisons never raise, whereas quantizing a huge
    # value can exceed the decimal context precision.
    if amount < Decimal("1900") or amount > Decimal("2099") or amount != amount.to_integral_value():
        return False
    source = str(text or "")
    before = source[max(0, start - 8):start]
    after = source[end:end + 10]
    if re.match(r"\s*(?:年|[-/.])\s*\d{1,2}", after):
        return True
    return bool(re.search(r"\d{1,2}\s*(?:年|[-/.])\s*$", before))


def _extract_date(text: str, *, document_type: str = "") -> str:
    """Return a normalized ``YYYY-MM-DD[ HH:MM]`` date from ``text``, or ``""``.

    For trip document types the best-scored date outside an invoice-date
    context wins (earliest on ties); if every date is an invoice date, ``""``
    is returned on purpose.  For other types the first date is used.
    """
    matches = list(DATE_PATTERN.finditer(text))
    if not matches:
        return ""
    normalized_type = str(document_type or "").strip().lower()
    if normalized_type not in TRIP_DATE_LABEL_BY_DOCUMENT_TYPE:
        return _format_date_match_with_time(text, matches[0])

    candidates: list[tuple[int, int, str]] = []
    for index, match in enumerate(matches):
        if _is_invoice_date_context(text, match):
            continue
        value = _format_date_match_with_time(text, match)
        if value:
            # -index makes max() prefer the earliest date among equal scores.
            candidates.append((_score_trip_date_context(text, match, value, False), -index, value))
    if not candidates:
        return ""
    return max(candidates)[2]


def _format_date_match_with_time(text: str, match: re.Match[str]) -> str:
    """Normalize a ``DATE_PATTERN`` match to ``YYYY-MM-DD``, appending ``HH:MM`` if one is nearby."""
    raw_value = str(match.group(1) or "").strip()
    normalized = raw_value.replace("年", "-").replace("月", "-").replace("日", "")
    normalized = normalized.replace("/", "-").replace(".", "-")
    parts = [part for part in normalized.split("-") if part]
    if len(parts) != 3:
        return raw_value
    year, month, day = parts
    date_value = f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}"
    # Look for a clock time shortly before or after the date.
    surrounding = str(text or "")[max(0, match.start() - 18): match.end() + 24]
    time_match = TIME_PATTERN.search(surrounding)
    if time_match:
        hour = str(time_match.group(1) or "").zfill(2)
        minute = str(time_match.group(2) or "").zfill(2)
        return f"{date_value} {hour}:{minute}"
    return date_value


def _is_invoice_date_context(text: str, match: re.Match[str]) -> bool:
    """Return True when "开票"-style wording appears right around the date match."""
    window = str(text or "")[max(0, match.start() - 12): match.end() + 8]
    compact = window.replace(" ", "")
    return any(token in compact for token in ("开票日期", "发票日期", "开票时间", "开票"))


def _score_trip_date_context(
    text: str,
    match: re.Match[str],
    value: str,
    invoice_context: bool,
) -> int:
    """Score how likely a date is the trip date, from its value and surrounding text."""
    window = str(text or "")[max(0, match.start() - 32): match.end() + 32]
    compact = window.replace(" ", "")
    score = -20 if invoice_context else 0
    if ":" in value or "：" in value:
        score += 8
    if any(token in compact for token in _TRIP_DATE_LABEL_TOKENS):
        score += 6
    if any(token in compact for token in ("车次", "检票", "二等座", "一等座", "商务座", "软卧", "硬卧")):
        score += 3
    if re.search(r"[A-Z]\d{1,4}", compact):
        score += 2
    if re.search(
        r"[\u4e00-\u9fa5A-Za-z0-9()（）·]{2,20}(?:至|到|→|->|—|–|-)[\u4e00-\u9fa5A-Za-z0-9()（）·]{2,20}",
        compact,
    ):
        score += 2
    return score


def _extract_merchant(text: str) -> str:
    """Return the first non-empty merchant name matched by ``MERCHANT_PATTERNS``, or ``""``."""
    for pattern in MERCHANT_PATTERNS:
        match = pattern.search(text)
        if not match:
            continue
        value = _clean_field_value(match.group(1))
        if value:
            return value
    return ""


def _extract_route(text: str) -> str:
    """Return ``"origin-destination"`` from ``ROUTE_PATTERN``, or ``""`` if missing or degenerate."""
    match = ROUTE_PATTERN.search(text)
    if not match:
        return ""
    start = _clean_field_value(match.group(1))
    end = _clean_field_value(match.group(2))
    if not start or not end or start == end:
        return ""
    return f"{start}-{end}"


def _extract_pattern(pattern: re.Pattern[str], text: str) -> str:
    """Return the cleaned first capture group of ``pattern`` in ``text``, or ``""``."""
    match = pattern.search(text)
    if not match:
        return ""
    return _clean_field_value(match.group(1))


def _clean_field_value(value: str) -> str:
    """Strip whitespace and surrounding ASCII/full-width punctuation (:：,，。.;；)."""
    return str(value or "").strip().strip(":：,，。.;；")