From 8b39f48dec41fc2ef0fe3d78487151b8f31dd159 Mon Sep 17 00:00:00 2001 From: caoxiaozhu Date: Thu, 14 May 2026 09:32:15 +0000 Subject: [PATCH] =?UTF-8?q?feat(server):=20=E6=96=B0=E5=A2=9E=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E6=99=BA=E8=83=BD=E8=AF=86=E5=88=AB=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=EF=BC=8C=E6=89=A9=E5=B1=95OCR=E6=8E=A5=E5=8F=A3=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20Azure=20Document=20Intelligence?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/src/app/api/v1/endpoints/ocr.py | 6 +- server/src/app/schemas/ocr.py | 19 + .../src/app/services/document_intelligence.py | 582 ++++++++++++++++++ server/src/app/services/ocr.py | 388 ++++++++++-- server/tests/test_document_intelligence.py | 66 ++ server/tests/test_ocr_endpoints.py | 22 +- server/tests/test_ocr_service.py | 106 +++- 7 files changed, 1128 insertions(+), 61 deletions(-) create mode 100644 server/src/app/services/document_intelligence.py create mode 100644 server/tests/test_document_intelligence.py diff --git a/server/src/app/api/v1/endpoints/ocr.py b/server/src/app/api/v1/endpoints/ocr.py index cba8342..6f016d9 100644 --- a/server/src/app/api/v1/endpoints/ocr.py +++ b/server/src/app/api/v1/endpoints/ocr.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Annotated from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from sqlalchemy.orm import Session -from app.api.deps import CurrentUserContext, get_current_user +from app.api.deps import CurrentUserContext, get_current_user, get_db from app.schemas.common import ErrorResponse from app.schemas.ocr import OcrRecognizeBatchRead from app.services.ocr import OcrService @@ -35,6 +36,7 @@ router = APIRouter(prefix="/ocr") async def recognize_ocr_documents( files: Annotated[list[UploadFile], File(description="待识别的票据图片或 PDF。")], _: Annotated[CurrentUserContext, Depends(get_current_user)], + db: Annotated[Session, Depends(get_db)], ) -> OcrRecognizeBatchRead: try: payload = [] @@ -46,7 +48,7 @@ async def recognize_ocr_documents( upload.content_type, ) ) - return OcrService().recognize_files(payload) + return OcrService(db).recognize_files(payload) except ValueError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc except RuntimeError as exc: diff --git a/server/src/app/schemas/ocr.py b/server/src/app/schemas/ocr.py index 6038b23..4120214 100644 --- a/server/src/app/schemas/ocr.py +++ b/server/src/app/schemas/ocr.py @@ -10,6 +10,12 @@ class OcrRecognizeLineRead(BaseModel): page_index: int | None = Field(default=None, description="页码,从 0 开始。") +class OcrRecognizeFieldRead(BaseModel): + key: str = Field(description="结构化字段键。") + label: str = Field(description="结构化字段展示名。") + value: str = Field(default="", description="结构化字段值。") + + class OcrRecognizeDocumentRead(BaseModel): filename: str = Field(description="原始文件名。") media_type: str = Field(description="文件媒体类型。") @@ -20,6 +26,19 @@ class OcrRecognizeDocumentRead(BaseModel): avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="平均识别置信度。") line_count: int = Field(default=0, ge=0, description="文本行数。") page_count: int = Field(default=1, ge=0, description="识别页数。") + document_type: str = Field(default="other", description="识别出的票据类型编码。") + document_type_label: str = Field(default="其他单据", description="识别出的票据类型名称。") + scene_code: str = Field(default="other", description="识别出的票据场景编码。") + scene_label: str = Field(default="其他票据", description="识别出的票据场景名称。") + classification_source: str = Field(default="rule", description="票据类型判断来源,例如 rule / llm_text / llm_vision。") + classification_confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="票据类型判断置信度。") + classification_evidence: list[str] = Field(default_factory=list, description="票据类型判断依据摘要。") + document_fields: list[OcrRecognizeFieldRead] = Field( + default_factory=list, + description="识别出的结构化票据信息。", + ) + preview_kind: str = Field(default="", description="预览类型,PDF 转图后通常为 image。") + preview_data_url: str = Field(default="", description="用于前端展示的图片预览 data URL。") warnings: list[str] = Field(default_factory=list, description="该文件的识别提示或警告。") lines: list[OcrRecognizeLineRead] = Field(default_factory=list, description="逐行识别结果。") diff --git a/server/src/app/services/document_intelligence.py b/server/src/app/services/document_intelligence.py new file mode 100644 index 0000000..795c4c4 --- /dev/null +++ b/server/src/app/services/document_intelligence.py @@ -0,0 +1,582 @@ +from __future__ import annotations + +import json +import re +from dataclasses import dataclass +from decimal import Decimal, InvalidOperation +from typing import Any + +from pydantic import BaseModel, Field, ValidationError +from sqlalchemy.orm import Session + +from app.services.runtime_chat import RuntimeChatService + + +@dataclass(frozen=True, slots=True) +class DocumentField: + key: str + label: str + value: str + + +@dataclass(frozen=True, slots=True) +class DocumentInsight: + document_type: str + document_type_label: str + scene_code: str + scene_label: str + expense_type: str + fields: tuple[DocumentField, ...] = () + classification_source: str = "rule" + classification_confidence: float = 0.0 + evidence: tuple[str, ...] = () + warnings: tuple[str, ...] = () + + +@dataclass(frozen=True, slots=True) +class DocumentRule: + document_type: str + document_type_label: str + scene_code: str + scene_label: str + expense_type: str + keywords: tuple[str, ...] + score_bias: float = 0.0 + + +@dataclass(frozen=True, slots=True) +class RuleMatch: + rule: DocumentRule | None + confidence: float + evidence: tuple[str, ...] + score: float + + +class LlmDocumentClassification(BaseModel): + document_type: str = Field(default="other") + scene_code: str = Field(default="other") + scene_label: str = Field(default="其他票据") + expense_type: str = Field(default="other") + confidence: float = Field(default=0.0, ge=0.0, le=1.0) + evidence: list[str] = Field(default_factory=list) + + +DEFAULT_RULE = DocumentRule( + document_type="other", + document_type_label="其他单据", + scene_code="other", + scene_label="其他票据", + expense_type="other", + keywords=(), + score_bias=0.0, +) + +DOCUMENT_RULES: tuple[DocumentRule, ...] = ( + DocumentRule( + document_type="flight_itinerary", + document_type_label="机票/航班行程单", + scene_code="travel", + scene_label="差旅票据", + expense_type="travel", + keywords=("电子行程单", "航班号", "航班", "机票", "登机", "航空", "客票"), + score_bias=0.34, + ), + DocumentRule( + document_type="train_ticket", + document_type_label="火车/高铁票", + scene_code="travel", + scene_label="差旅票据", + expense_type="travel", + keywords=("高铁", "火车", "动车", "铁路", "车次", "检票", "二等座", "一等座"), + score_bias=0.32, + ), + DocumentRule( + document_type="hotel_invoice", + document_type_label="酒店住宿票据", + scene_code="hotel", + scene_label="住宿票据", + expense_type="hotel", + keywords=("住宿", "房费", "客房", "入住", "离店", "酒店", "宾馆", "间夜"), + score_bias=0.16, + ), + DocumentRule( + document_type="taxi_receipt", + document_type_label="出租车/网约车票据", + scene_code="transport", + scene_label="交通票据", + expense_type="transport", + keywords=("滴滴出行", "滴滴", "网约车", "出租车", "打车", "快车", "专车", "订单号", "上车", "下车", "起点", "终点", "里程", "司机"), + score_bias=0.38, + ), + DocumentRule( + document_type="parking_toll_receipt", + document_type_label="停车/通行费票据", + scene_code="transport", + scene_label="交通票据", + expense_type="transport", + keywords=("停车费", "通行费", "过路费", "收费站", "停车场", "停车"), + score_bias=0.28, + ), + DocumentRule( + document_type="meal_receipt", + document_type_label="餐饮票据", + scene_code="meal", + scene_label="餐饮票据", + expense_type="meal", + keywords=("餐饮", "餐费", "用餐", "饭店", "酒楼", "餐厅", "食品", "外卖", "咖啡"), + score_bias=0.14, + ), + DocumentRule( + document_type="office_invoice", + document_type_label="办公用品票据", + scene_code="office", + scene_label="办公用品票据", + expense_type="office", + keywords=("办公用品", "文具", "耗材", "打印纸", "墨盒", "硒鼓", "键盘", "鼠标"), + score_bias=0.14, + ), + DocumentRule( + document_type="meeting_invoice", + document_type_label="会议/会务票据", + scene_code="meeting", + scene_label="会务票据", + expense_type="meeting", + keywords=("会议", "会务", "会展", "论坛", "会议室", "会场"), + score_bias=0.12, + ), + DocumentRule( + document_type="training_invoice", + document_type_label="培训票据", + scene_code="training", + scene_label="培训票据", + expense_type="training", + keywords=("培训", "课程", "讲师", "教材", "学费", "认证"), + score_bias=0.12, + ), + DocumentRule( + document_type="vat_invoice", + document_type_label="增值税发票", + scene_code="other", + scene_label="通用发票", + expense_type="other", + keywords=("发票代码", "发票号码", "价税合计", "增值税", "电子发票"), + score_bias=-0.08, + ), + DocumentRule( + document_type="receipt", + document_type_label="一般收据/凭证", + scene_code="other", + scene_label="其他票据", + expense_type="other", + keywords=("收据", "凭证", "票据"), + score_bias=-0.18, + ), +) + +DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES} +SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",) + +AMOUNT_PATTERNS = ( + re.compile(r"(?:价税合计|合计|金额|总额|票价|支付金额|实付金额|实收金额)[::\s¥¥]*([0-9]+(?:[.,][0-9]{1,2})?)"), + re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"), +) +DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)") +INVOICE_NUMBER_PATTERN = re.compile(r"(?:发票号码|票号|单号|订单号)[::\s]*([A-Za-z0-9-]{6,24})") +INVOICE_CODE_PATTERN = re.compile(r"(?:发票代码)[::\s]*([A-Za-z0-9-]{6,24})") +TRIP_NO_PATTERN = re.compile(r"(?:车次|航班(?:号)?)[::\s]*([A-Za-z0-9]{2,12})") +ROUTE_PATTERN = re.compile(r"([\u4e00-\u9fa5]{2,12})\s*(?:至|→|->|-)\s*([\u4e00-\u9fa5]{2,12})") +MERCHANT_PATTERNS = ( + re.compile(r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)[::\s]*([A-Za-z0-9\u4e00-\u9fa5()()·&\\-]{2,40})"), + re.compile(r"([A-Za-z0-9\u4e00-\u9fa5()()·&\\-]{2,40}(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"), +) + + +class DocumentIntelligenceService: + def __init__(self, db: Session | None = None) -> None: + self.runtime_chat_service = RuntimeChatService(db) if db is not None else None + + def build_document_insight( + self, + *, + filename: str = "", + summary: str = "", + text: str = "", + preview_data_url: str = "", + ) -> DocumentInsight: + raw_text = " ".join( + [str(filename or "").strip(), str(summary or "").strip(), str(text or "").strip()] + ).strip() + compact = re.sub(r"\s+", "", raw_text).lower() + rule_match = _match_document_rule(compact) + base_rule = rule_match.rule or DEFAULT_RULE + fields = tuple(_extract_document_fields(raw_text)) + rule_insight = DocumentInsight( + document_type=base_rule.document_type, + document_type_label=base_rule.document_type_label, + scene_code=base_rule.scene_code, + scene_label=base_rule.scene_label, + expense_type=base_rule.expense_type, + fields=fields, + classification_source="rule", + classification_confidence=rule_match.confidence, + evidence=rule_match.evidence, + ) + + llm_result = self._classify_with_model( + filename=str(filename or "").strip(), + summary=str(summary or "").strip(), + text=str(text or "").strip(), + preview_data_url=str(preview_data_url or "").strip(), + rule_insight=rule_insight, + fields=fields, + ) + if llm_result is None: + return rule_insight + return self._merge_rule_and_model( + rule_insight=rule_insight, + llm_result=llm_result, + fields=fields, + has_preview=bool(preview_data_url), + ) + + def _classify_with_model( + self, + *, + filename: str, + summary: str, + text: str, + preview_data_url: str, + rule_insight: DocumentInsight, + fields: tuple[DocumentField, ...], + ) -> tuple[str, LlmDocumentClassification] | None: + if self.runtime_chat_service is None: + return None + + trimmed_text = text.strip() + if not trimmed_text and not summary.strip(): + return None + + facts = { + "filename": filename, + "summary": summary[:300], + "ocr_text_excerpt": trimmed_text[:2000], + "rule_candidate": { + "document_type": rule_insight.document_type, + "document_type_label": rule_insight.document_type_label, + "scene_code": rule_insight.scene_code, + "scene_label": rule_insight.scene_label, + "expense_type": rule_insight.expense_type, + "confidence": round(rule_insight.classification_confidence, 2), + "evidence": list(rule_insight.evidence), + }, + "extracted_fields": [ + {"key": field.key, "label": field.label, "value": field.value} + for field in fields + ], + "allowed_document_types": list(SUPPORTED_DOCUMENT_TYPES), + } + + system_prompt = ( + "你是企业报销票据识别复核器。" + "你的任务不是 OCR,而是在已有 OCR 文本和票据预览基础上判断票据类型。" + "只输出 JSON 对象,不要输出 Markdown、解释或代码块。" + "document_type 只能是:" + f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}。" + "如果证据不足,返回 other。" + "严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。" + "如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。" + "例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。" + "输出字段:document_type, scene_code, scene_label, expense_type, confidence, evidence。" + ) + user_prompt = ( + "请根据以下票据事实给出最终分类 JSON:\n" + f"{json.dumps(facts, ensure_ascii=False, indent=2)}\n\n" + "示例输出:\n" + "{\n" + ' "document_type": "taxi_receipt",\n' + ' "scene_code": "transport",\n' + ' "scene_label": "交通票据",\n' + ' "expense_type": "transport",\n' + ' "confidence": 0.86,\n' + ' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"]\n' + "}" + ) + + if preview_data_url: + response_text = self.runtime_chat_service.complete( + [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": user_prompt}, + {"type": "image_url", "image_url": {"url": preview_data_url}}, + ], + }, + ], + slot_priority=("vlm",), + max_tokens=320, + temperature=0.0, + ) + parsed = self._parse_llm_payload(response_text) + if parsed is not None: + return "llm_vision", parsed + + response_text = self.runtime_chat_service.complete( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + slot_priority=("main", "backup"), + max_tokens=320, + temperature=0.0, + ) + parsed = self._parse_llm_payload(response_text) + if parsed is not None: + return "llm_text", parsed + return None + + @staticmethod + def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None: + payload_json = _extract_json_payload(response_text) + if payload_json is None: + return None + + try: + parsed = LlmDocumentClassification.model_validate(payload_json) + except ValidationError: + return None + + normalized_type = str(parsed.document_type or "other").strip().lower() or "other" + if normalized_type not in SUPPORTED_DOCUMENT_TYPES: + normalized_type = "other" + + base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE) + evidence = [ + str(item or "").strip() + for item in parsed.evidence + if str(item or "").strip() + ][:4] + + return LlmDocumentClassification( + document_type=normalized_type, + scene_code=str(parsed.scene_code or base_rule.scene_code).strip() or base_rule.scene_code, + scene_label=str(parsed.scene_label or base_rule.scene_label).strip() or base_rule.scene_label, + expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type, + confidence=float(parsed.confidence), + evidence=evidence, + ) + + @staticmethod + def _merge_rule_and_model( + *, + rule_insight: DocumentInsight, + llm_result: tuple[str, LlmDocumentClassification], + fields: tuple[DocumentField, ...], + has_preview: bool, + ) -> DocumentInsight: + source, parsed = llm_result + if parsed.confidence < 0.55: + return rule_insight + + should_override = False + if parsed.document_type == rule_insight.document_type: + should_override = True + elif rule_insight.document_type == "other" and parsed.document_type != "other": + should_override = True + elif parsed.document_type != "other": + threshold = 0.60 if has_preview else max(0.76, rule_insight.classification_confidence + 0.12) + should_override = parsed.confidence >= threshold + + if not should_override: + return rule_insight + + rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE) + warnings = list(rule_insight.warnings) + if parsed.document_type != rule_insight.document_type: + warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。") + + return DocumentInsight( + document_type=rule.document_type, + document_type_label=rule.document_type_label, + scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code, + scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label, + expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type, + fields=fields, + classification_source=source, + classification_confidence=max(parsed.confidence, rule_insight.classification_confidence), + evidence=tuple(parsed.evidence or rule_insight.evidence), + warnings=tuple(warnings), + ) + + +def build_document_insight( + *, + filename: str = "", + summary: str = "", + text: str = "", + preview_data_url: str = "", +) -> DocumentInsight: + return DocumentIntelligenceService().build_document_insight( + filename=filename, + summary=summary, + text=text, + preview_data_url=preview_data_url, + ) + + +def _match_document_rule(compact_text: str) -> RuleMatch: + best_rule = DEFAULT_RULE + best_evidence: tuple[str, ...] = () + best_score = 0.0 + + for rule in DOCUMENT_RULES: + matched = tuple(keyword for keyword in rule.keywords if keyword.lower() in compact_text) + if not matched: + continue + score = float(rule.score_bias) + len(matched) * 0.92 + sum(min(len(keyword), 6) * 0.08 for keyword in matched) + if score > best_score: + best_rule = rule + best_evidence = matched + best_score = score + + if best_score <= 0: + return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0) + + confidence = min(0.94, 0.30 + min(best_score, 4.8) * 0.12) + return RuleMatch( + rule=best_rule, + confidence=round(confidence, 2), + evidence=best_evidence[:4], + score=best_score, + ) + + +def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None: + if not response_text: + return None + + cleaned = re.sub(r".*?", "", response_text, flags=re.DOTALL | re.IGNORECASE).strip() + if not cleaned: + return None + + fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL) + candidates = [fenced_match.group(1)] if fenced_match else [] + candidates.append(cleaned) + + start = cleaned.find("{") + end = cleaned.rfind("}") + if start != -1 and end != -1 and end > start: + candidates.append(cleaned[start : end + 1]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + return parsed + return None + + +def _extract_document_fields(text: str) -> list[DocumentField]: + fields: list[DocumentField] = [] + amount = _extract_amount(text) + if amount: + fields.append(DocumentField(key="amount", label="金额", value=amount)) + + date_value = _extract_date(text) + if date_value: + fields.append(DocumentField(key="date", label="日期", value=date_value)) + + merchant = _extract_merchant(text) + if merchant: + fields.append(DocumentField(key="merchant_name", label="商户", value=merchant)) + + invoice_number = _extract_pattern(INVOICE_NUMBER_PATTERN, text) + if invoice_number: + fields.append(DocumentField(key="invoice_number", label="票据号码", value=invoice_number)) + + invoice_code = _extract_pattern(INVOICE_CODE_PATTERN, text) + if invoice_code: + fields.append(DocumentField(key="invoice_code", label="发票代码", value=invoice_code)) + + trip_no = _extract_pattern(TRIP_NO_PATTERN, text) + if trip_no: + fields.append(DocumentField(key="trip_no", label="车次/航班", value=trip_no)) + + route = _extract_route(text) + if route: + fields.append(DocumentField(key="route", label="行程", value=route)) + + return fields + + +def _extract_amount(text: str) -> str: + best_value: Decimal | None = None + for pattern in AMOUNT_PATTERNS: + for match in pattern.finditer(text): + raw_value = str(match.group(1) or "").replace(",", ".").strip() + try: + candidate = Decimal(raw_value) + except InvalidOperation: + continue + if candidate <= Decimal("0.00"): + continue + if best_value is None or candidate > best_value: + best_value = candidate + if best_value is not None: + break + + if best_value is None: + return "" + normalized = best_value.quantize(Decimal("0.01")) + text_value = format(normalized, "f").rstrip("0").rstrip(".") + return f"{text_value}元" + + +def _extract_date(text: str) -> str: + match = DATE_PATTERN.search(text) + if not match: + return "" + raw_value = str(match.group(1) or "").strip() + normalized = raw_value.replace("年", "-").replace("月", "-").replace("日", "") + normalized = normalized.replace("/", "-").replace(".", "-") + parts = [part for part in normalized.split("-") if part] + if len(parts) != 3: + return raw_value + year, month, day = parts + return f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}" + + +def _extract_merchant(text: str) -> str: + for pattern in MERCHANT_PATTERNS: + match = pattern.search(text) + if not match: + continue + value = _clean_field_value(match.group(1)) + if value: + return value + return "" + + +def _extract_route(text: str) -> str: + match = ROUTE_PATTERN.search(text) + if not match: + return "" + start = _clean_field_value(match.group(1)) + end = _clean_field_value(match.group(2)) + if not start or not end or start == end: + return "" + return f"{start}-{end}" + + +def _extract_pattern(pattern: re.Pattern[str], text: str) -> str: + match = pattern.search(text) + if not match: + return "" + return _clean_field_value(match.group(1)) + + +def _clean_field_value(value: str) -> str: + return str(value or "").strip().strip("::,,。.;;") diff --git a/server/src/app/services/ocr.py b/server/src/app/services/ocr.py index 6a5b9b2..c2e38cf 100644 --- a/server/src/app/services/ocr.py +++ b/server/src/app/services/ocr.py @@ -1,21 +1,55 @@ from __future__ import annotations +import base64 import json import shutil import subprocess +from dataclasses import dataclass, field from pathlib import Path from uuid import uuid4 +from sqlalchemy.orm import Session + from app.core.config import SERVER_DIR, get_settings -from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead +from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead +from app.services.document_intelligence import DocumentIntelligenceService WORKER_JSON_PREFIX = "__OCR_JSON__=" SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"} +@dataclass(slots=True) +class PreparedOcrInput: + input_path: Path + source_key: str + filename: str + media_type: str + page_index: int | None = None + preview_kind: str = "" + preview_data_url: str = "" + + +@dataclass(slots=True) +class AggregatedOcrDocument: + filename: str + media_type: str + source_key: str + engine: str = "paddleocr_mobile" + model: str = "PP-OCRv5_mobile" + summary_fragments: list[str] = field(default_factory=list) + text_fragments: list[str] = field(default_factory=list) + score_values: list[float] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + lines: list[OcrRecognizeLineRead] = field(default_factory=list) + page_count: int = 0 + preview_kind: str = "" + preview_data_url: str = "" + + class OcrService: - def __init__(self) -> None: + def __init__(self, db: Session | None = None) -> None: self.settings = get_settings() + self.document_intelligence_service = DocumentIntelligenceService(db) def recognize_files( self, @@ -28,10 +62,11 @@ class OcrService: temp_root.mkdir(parents=True, exist_ok=True) documents: list[OcrRecognizeDocumentRead] = [] - input_paths: list[Path] = [] - meta_by_path: dict[str, tuple[str, str]] = {} + prepared_inputs: list[PreparedOcrInput] = [] + cleanup_paths: list[Path] = [] python_bin = self._resolve_python_bin() worker_path = self._resolve_worker_path() + worker_payload: dict = {} try: for filename, content, media_type in files: @@ -73,17 +108,55 @@ class OcrService: temp_path = temp_root / f"{uuid4().hex}{suffix}" temp_path.write_bytes(content) - input_paths.append(temp_path) - meta_by_path[str(temp_path)] = (normalized_name, resolved_media_type) + cleanup_paths.append(temp_path) - if input_paths: + if suffix == ".pdf": + try: + prepared_inputs.extend( + self._prepare_pdf_inputs( + pdf_path=temp_path, + filename=normalized_name, + media_type=resolved_media_type, + cleanup_paths=cleanup_paths, + ) + ) + except RuntimeError as exc: + documents.append( + OcrRecognizeDocumentRead( + filename=normalized_name, + media_type=resolved_media_type, + warnings=[str(exc)], + ) + ) + continue + + prepared_inputs.append( + PreparedOcrInput( + input_path=temp_path, + source_key=uuid4().hex, + filename=normalized_name, + media_type=resolved_media_type, + preview_kind="image" if resolved_media_type.startswith("image/") else "", + preview_data_url=( + self._build_preview_data_url(temp_path, media_type=resolved_media_type) + if resolved_media_type.startswith("image/") + else "" + ), + ) + ) + + if prepared_inputs: worker_payload = self._invoke_worker( python_bin=python_bin, worker_path=worker_path, - input_paths=input_paths, + input_paths=[item.input_path for item in prepared_inputs], + ) + documents.extend( + self._build_documents( + worker_documents=worker_payload.get("documents", []), + prepared_inputs=prepared_inputs, + ) ) - for item in worker_payload.get("documents", []): - documents.append(self._build_document(item, meta_by_path)) success_count = sum( 1 @@ -92,12 +165,12 @@ class OcrService: ) engine = ( str(worker_payload.get("engine", "paddleocr_mobile")) - if input_paths + if prepared_inputs else "paddleocr_mobile" ) model = ( str(worker_payload.get("model", "PP-OCRv5_mobile")) - if input_paths + if prepared_inputs else "PP-OCRv5_mobile" ) return OcrRecognizeBatchRead( @@ -108,8 +181,7 @@ class OcrService: documents=documents, ) finally: - for path in input_paths: - path.unlink(missing_ok=True) + self._cleanup_temp_paths(cleanup_paths) def _resolve_python_bin(self) -> str: candidates = [] @@ -182,40 +254,258 @@ class OcrService: return json.loads(normalized[len(WORKER_JSON_PREFIX) :]) return None - @staticmethod - def _build_document( - payload: dict, - meta_by_path: dict[str, tuple[str, str]], - ) -> OcrRecognizeDocumentRead: - input_path = str(payload.get("input_path") or "") - filename, media_type = meta_by_path.get( - input_path, - (Path(input_path).name or "upload.bin", "application/octet-stream"), - ) - lines = [ - OcrRecognizeLineRead( - text=str(item.get("text", "")), - score=float(item.get("score", 0.0) or 0.0), - box=[ - [int(point[0]), int(point[1])] - for point in item.get("box", []) - if isinstance(point, list) and len(point) == 2 - ], - page_index=int(item["page_index"]) if item.get("page_index") is not None else None, + def _prepare_pdf_inputs( + self, + *, + pdf_path: Path, + filename: str, + media_type: str, + cleanup_paths: list[Path], + ) -> list[PreparedOcrInput]: + output_dir = pdf_path.with_suffix("") + output_dir.mkdir(parents=True, exist_ok=True) + cleanup_paths.append(output_dir) + + image_paths = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=output_dir) + if not image_paths: + raise RuntimeError("PDF 转图片后未生成可识别页面。") + + preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png") + source_key = uuid4().hex + descriptors: list[PreparedOcrInput] = [] + for page_index, image_path in enumerate(image_paths): + descriptors.append( + PreparedOcrInput( + input_path=image_path, + source_key=source_key, + filename=filename, + media_type=media_type, + page_index=page_index, + preview_kind="image" if page_index == 0 else "", + preview_data_url=preview_data_url if page_index == 0 else "", + ) ) - for item in payload.get("lines", []) - if isinstance(item, dict) - ] - return OcrRecognizeDocumentRead( - filename=filename, - media_type=media_type, - engine=str(payload.get("engine", "paddleocr_mobile")), - model=str(payload.get("model", "PP-OCRv5_mobile")), - text=str(payload.get("text", "")), - summary=str(payload.get("summary", "")), - avg_score=float(payload.get("avg_score", 0.0) or 0.0), - line_count=int(payload.get("line_count", len(lines)) or 0), - page_count=int(payload.get("page_count", 1) or 1), - warnings=[str(item) for item in payload.get("warnings", [])], - lines=lines, + return descriptors + + def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]: + prefix = output_dir / "page" + completed = subprocess.run( + [ + "pdftoppm", + "-png", + "-r", + "160", + str(pdf_path), + str(prefix), + ], + capture_output=True, + text=True, + timeout=self.settings.ocr_timeout_seconds, + check=False, ) + if completed.returncode != 0: + detail = (completed.stderr or completed.stdout or "").strip() + raise RuntimeError(f"PDF 转图片失败:{detail or 'pdftoppm 返回非 0 状态码。'}") + + return sorted(output_dir.glob("page-*.png"), key=self._extract_pdf_page_sort_key) + + @staticmethod + def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]: + suffix = path.stem.rsplit("-", 1)[-1] + try: + return int(suffix), path.name + except ValueError: + return 0, path.name + + @staticmethod + def _build_preview_data_url(path: Path, *, media_type: str) -> str: + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + return f"data:{media_type};base64,{encoded}" + + def _build_documents( + self, + *, + worker_documents: list[dict], + prepared_inputs: list[PreparedOcrInput], + ) -> list[OcrRecognizeDocumentRead]: + descriptor_by_path = {str(item.input_path): item for item in prepared_inputs} + source_order: list[str] = [] + seen_sources: set[str] = set() + for item in prepared_inputs: + if item.source_key in seen_sources: + continue + seen_sources.add(item.source_key) + source_order.append(item.source_key) + + aggregated_by_source: dict[str, AggregatedOcrDocument] = {} + for payload in worker_documents: + if not isinstance(payload, dict): + continue + input_path = str(payload.get("input_path") or "") + descriptor = descriptor_by_path.get(input_path) + if descriptor is None: + continue + + aggregated = aggregated_by_source.get(descriptor.source_key) + if aggregated is None: + aggregated = AggregatedOcrDocument( + filename=descriptor.filename, + media_type=descriptor.media_type, + source_key=descriptor.source_key, + engine=str(payload.get("engine", "paddleocr_mobile")), + model=str(payload.get("model", "PP-OCRv5_mobile")), + ) + aggregated_by_source[descriptor.source_key] = aggregated + + aggregated.page_count = max( + aggregated.page_count, + (descriptor.page_index + 1) + if descriptor.page_index is not None + else int(payload.get("page_count", 1) or 1), + ) + if descriptor.preview_kind and not aggregated.preview_kind: + aggregated.preview_kind = descriptor.preview_kind + if descriptor.preview_data_url and not aggregated.preview_data_url: + aggregated.preview_data_url = descriptor.preview_data_url + + page_summary = str(payload.get("summary", "") or "").strip() + if page_summary: + aggregated.summary_fragments.append(page_summary) + + page_text = str(payload.get("text", "") or "").strip() + if page_text: + aggregated.text_fragments.append(page_text) + + lines = self._build_lines( + payload.get("lines", []), + page_index_override=descriptor.page_index, + ) + aggregated.lines.extend(lines) + aggregated.score_values.extend(line.score for line in lines if line.score > 0) + + if not lines: + avg_score = float(payload.get("avg_score", 0.0) or 0.0) + if avg_score > 0: + aggregated.score_values.append(avg_score) + + for warning in payload.get("warnings", []): + normalized_warning = str(warning or "").strip() + if normalized_warning and normalized_warning not in aggregated.warnings: + aggregated.warnings.append(normalized_warning) + + documents: list[OcrRecognizeDocumentRead] = [] + for source_key in source_order: + descriptors = [item for item in prepared_inputs if item.source_key == source_key] + if not descriptors: + continue + aggregated = aggregated_by_source.get(source_key) + if aggregated is None: + first_descriptor = descriptors[0] + documents.append( + OcrRecognizeDocumentRead( + filename=first_descriptor.filename, + media_type=first_descriptor.media_type, + page_count=max(1, len(descriptors)), + preview_kind=first_descriptor.preview_kind, + preview_data_url=first_descriptor.preview_data_url, + warnings=["OCR worker 未返回该文件的识别结果。"], + ) + ) + continue + documents.append(self._finalize_document(aggregated)) + + return documents + + @staticmethod + def _build_lines( + items: list[dict], + *, + page_index_override: int | None = None, + ) -> list[OcrRecognizeLineRead]: + lines: list[OcrRecognizeLineRead] = [] + for item in items: + if not isinstance(item, dict): + continue + page_index = page_index_override + if page_index is None and item.get("page_index") is not None: + page_index = int(item["page_index"]) + lines.append( + OcrRecognizeLineRead( + text=str(item.get("text", "")), + score=float(item.get("score", 0.0) or 0.0), + box=[ + [int(point[0]), int(point[1])] + for point in item.get("box", []) + if isinstance(point, list) and len(point) == 2 + ], + page_index=page_index, + ) + ) + return lines + + @staticmethod + def _truncate_summary(parts: list[str]) -> str: + summary = ";".join([part for part in parts if part][:3]) + if len(summary) > 180: + return f"{summary[:177]}..." + return summary + + def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead: + full_text = "\n".join(fragment for fragment in aggregated.text_fragments if fragment).strip() + summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments) + insight = self.document_intelligence_service.build_document_insight( + filename=aggregated.filename, + summary=summary, + text=full_text, + preview_data_url=aggregated.preview_data_url, + ) + warnings = list(aggregated.warnings) + for warning in insight.warnings: + normalized_warning = str(warning or "").strip() + if normalized_warning and normalized_warning not in warnings: + warnings.append(normalized_warning) + return OcrRecognizeDocumentRead( + filename=aggregated.filename, + media_type=aggregated.media_type, + engine=aggregated.engine, + model=aggregated.model, + text=full_text, + summary=summary, + avg_score=( + sum(aggregated.score_values) / len(aggregated.score_values) + if aggregated.score_values + else 0.0 + ), + line_count=len(aggregated.lines), + page_count=max(1, aggregated.page_count), + document_type=insight.document_type, + document_type_label=insight.document_type_label, + scene_code=insight.scene_code, + scene_label=insight.scene_label, + classification_source=insight.classification_source, + classification_confidence=insight.classification_confidence, + classification_evidence=list(insight.evidence), + document_fields=[ + OcrRecognizeFieldRead( + key=field.key, + label=field.label, + value=field.value, + ) + for field in insight.fields + ], + preview_kind=aggregated.preview_kind, + preview_data_url=aggregated.preview_data_url, + warnings=warnings, + lines=sorted( + aggregated.lines, + key=lambda item: item.page_index if item.page_index is not None else -1, + ), + ) + + @staticmethod + def _cleanup_temp_paths(paths: list[Path]) -> None: + for path in reversed(paths): + if path.is_dir(): + shutil.rmtree(path, ignore_errors=True) + continue + path.unlink(missing_ok=True) diff --git a/server/tests/test_document_intelligence.py b/server/tests/test_document_intelligence.py new file mode 100644 index 0000000..7d57e2d --- /dev/null +++ b/server/tests/test_document_intelligence.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import json + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from app.services.document_intelligence import DocumentIntelligenceService, build_document_insight +from app.services.runtime_chat import RuntimeChatService + + +def test_build_document_insight_prefers_transport_for_didi_text_with_hotel_noise() -> None: + insight = build_document_insight( + filename="didi-trip.png", + summary="滴滴出行行程单", + text="滴滴出行电子发票 订单号 12345678 上车点 深圳湾 下车点 后海 全季酒店 里程 12.4 公里 金额 48 元", + ) + + assert insight.document_type == "taxi_receipt" + assert insight.document_type_label == "出租车/网约车票据" + assert insight.scene_code == "transport" + assert any(field.label == "金额" and field.value == "48元" for field in insight.fields) + + +def test_document_intelligence_service_uses_vlm_result_when_preview_available(monkeypatch) -> None: + calls: list[tuple[str, ...]] = [] + + def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2): + calls.append(slot_priority) + if slot_priority == ("vlm",): + assert isinstance(messages[1]["content"], list) + return json.dumps( + { + "document_type": "taxi_receipt", + "scene_code": "transport", + "scene_label": "交通票据", + "expense_type": "transport", + "confidence": 0.91, + "evidence": ["图片主体为滴滴行程单,OCR 中出现订单号、上车、下车等字段"], + }, + ensure_ascii=False, + ) + return None + + monkeypatch.setattr(RuntimeChatService, "complete", fake_complete) + + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + session = sessionmaker(bind=engine, autoflush=False, autocommit=False)() + try: + insight = DocumentIntelligenceService(session).build_document_insight( + filename="mixed-noise.png", + summary="OCR 混入酒店名称", + text="全季酒店 滴滴出行 订单号 12345678 上车 下车 金额 52 元", + preview_data_url="data:image/png;base64,ZmFrZQ==", + ) + finally: + session.close() + + assert insight.document_type == "taxi_receipt" + assert insight.classification_source == "llm_vision" + assert calls[0] == ("vlm",) diff --git a/server/tests/test_ocr_endpoints.py b/server/tests/test_ocr_endpoints.py index 0a1b0b7..14b9c82 100644 --- a/server/tests/test_ocr_endpoints.py +++ b/server/tests/test_ocr_endpoints.py @@ -10,7 +10,7 @@ from sqlalchemy.pool import StaticPool from app.api.deps import get_db from app.db.base import Base from app.main import create_app -from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead +from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead from app.services.ocr import OcrService @@ -50,14 +50,23 @@ def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None: OcrRecognizeDocumentRead( filename="invoice.png", media_type="image/png", - text="发票金额 100 元", - summary="发票金额 100 元", + text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13", + summary="增值税电子发票,金额 100 元。", avg_score=0.98, line_count=1, page_count=1, + document_type="vat_invoice", + document_type_label="增值税发票", + scene_code="other", + scene_label="通用发票", + document_fields=[ + OcrRecognizeFieldRead(key="amount", label="金额", value="100元"), + OcrRecognizeFieldRead(key="date", label="日期", value="2026-05-13"), + OcrRecognizeFieldRead(key="invoice_number", label="票据号码", value="12345678"), + ], lines=[ OcrRecognizeLineRead( - text="发票金额 100 元", + text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13", score=0.98, box=[[1, 2], [10, 2], [10, 8], [1, 8]], page_index=0, @@ -81,4 +90,7 @@ def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None: assert payload["engine"] == "paddleocr_mobile" assert payload["success_count"] == 1 assert payload["documents"][0]["filename"] == "invoice.png" - assert payload["documents"][0]["summary"] == "发票金额 100 元" + assert payload["documents"][0]["summary"] == "增值税电子发票,金额 100 元。" + assert payload["documents"][0]["document_type"] == "vat_invoice" + assert payload["documents"][0]["document_type_label"] == "增值税发票" + assert payload["documents"][0]["document_fields"][0]["label"] == "金额" diff --git a/server/tests/test_ocr_service.py b/server/tests/test_ocr_service.py index 37cba1f..8141050 100644 --- a/server/tests/test_ocr_service.py +++ b/server/tests/test_ocr_service.py @@ -26,15 +26,15 @@ for index, arg in enumerate(sys.argv): "input_path": input_path, "engine": "paddleocr_mobile", "model": "PP-OCRv5_mobile", - "text": "发票金额 100 元", - "summary": "发票金额 100 元", + "text": "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13", + "summary": "增值税电子发票,金额 100 元。", "avg_score": 0.98, "line_count": 1, "page_count": 1, "warnings": [], "lines": [ { - "text": "发票金额 100 元", + "text": "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13", "score": 0.98, "box": [[1, 2], [10, 2], [10, 8], [1, 8]], "page_index": 0, @@ -74,10 +74,106 @@ print("__OCR_JSON__=" + json.dumps(payload, ensure_ascii=False)) assert len(result.documents) == 2 recognized = next(item for item in result.documents if item.filename == "invoice.png") - assert recognized.summary == "发票金额 100 元" + assert recognized.summary == "增值税电子发票,金额 100 元。" assert recognized.line_count == 1 - assert recognized.lines[0].text == "发票金额 100 元" + assert recognized.document_type == "vat_invoice" + assert recognized.document_type_label == "增值税发票" + assert any(field.label == "金额" and field.value == "100元" for field in recognized.document_fields) + assert any(field.label == "票据号码" and field.value == "12345678" for field in recognized.document_fields) + assert any(field.label == "日期" and field.value == "2026-05-13" for field in recognized.document_fields) + assert recognized.lines[0].text == "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13" skipped = next(item for item in result.documents if item.filename == "notes.txt") assert skipped.line_count == 0 assert skipped.warnings == ["当前仅支持图片和 PDF 文件进行 OCR。"] + + +def test_ocr_service_converts_pdf_to_images_and_returns_image_preview( + monkeypatch, + tmp_path: Path, +) -> None: + def fake_convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]: + first = output_dir / "page-1.png" + second = output_dir / "page-2.png" + first.write_bytes(b"fake-page-1") + second.write_bytes(b"fake-page-2") + return [first, second] + + def fake_invoke_worker( + self, + *, + python_bin: str, + worker_path: str, + input_paths: list[Path], + ) -> dict: + assert [path.name for path in input_paths] == ["page-1.png", "page-2.png"] + return { + "engine": "paddleocr_mobile", + "model": "PP-OCRv5_mobile", + "documents": [ + { + "input_path": str(input_paths[0]), + "engine": "paddleocr_mobile", + "model": "PP-OCRv5_mobile", + "text": "高铁票 深圳北-广州南 车次 G1234 2026-05-13 金额 188 元", + "summary": "高铁票第一页", + "avg_score": 0.97, + "line_count": 1, + "page_count": 1, + "warnings": [], + "lines": [ + { + "text": "高铁票 深圳北-广州南 车次 G1234 2026-05-13 金额 188 元", + "score": 0.97, + "box": [[1, 2], [10, 2], [10, 8], [1, 8]], + } + ], + }, + { + "input_path": str(input_paths[1]), + "engine": "paddleocr_mobile", + "model": "PP-OCRv5_mobile", + "text": "乘车人 张三", + "summary": "高铁票第二页", + "avg_score": 0.94, + "line_count": 1, + "page_count": 1, + "warnings": [], + "lines": [ + { + "text": "乘车人 张三", + "score": 0.94, + "box": [[1, 2], [10, 2], [10, 8], [1, 8]], + } + ], + }, + ], + } + + monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage")) + monkeypatch.setattr(OcrService, "_resolve_python_bin", lambda self: "python") + monkeypatch.setattr(OcrService, "_resolve_worker_path", lambda self: "worker.py") + monkeypatch.setattr(OcrService, "_convert_pdf_to_images", fake_convert_pdf_to_images) + monkeypatch.setattr(OcrService, "_invoke_worker", fake_invoke_worker) + get_settings.cache_clear() + try: + result = OcrService().recognize_files( + [ + ("train-ticket.pdf", b"%PDF-1.4 fake", "application/pdf"), + ] + ) + finally: + get_settings.cache_clear() + + assert result.success_count == 1 + assert len(result.documents) == 1 + recognized = result.documents[0] + assert recognized.filename == "train-ticket.pdf" + assert recognized.page_count == 2 + assert recognized.preview_kind == "image" + assert recognized.preview_data_url.startswith("data:image/png;base64,") + assert recognized.document_type == "train_ticket" + assert any(field.label == "金额" and field.value == "188元" for field in recognized.document_fields) + assert any(field.label == "车次/航班" and field.value == "G1234" for field in recognized.document_fields) + assert recognized.lines[0].page_index == 0 + assert recognized.lines[1].page_index == 1