from __future__ import annotations import base64 import hashlib import json import re import shutil import subprocess import threading from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path from uuid import uuid4 from sqlalchemy.orm import Session from app.core.config import SERVER_DIR, get_settings from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead from app.services.document_intelligence import DocumentIntelligenceService WORKER_JSON_PREFIX = "__OCR_JSON__=" SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"} OCR_RESULT_CACHE_LIMIT = 32 @dataclass(slots=True) class PreparedOcrInput: input_path: Path source_key: str filename: str media_type: str page_index: int | None = None preview_kind: str = "" preview_data_url: str = "" text_layer: str = "" @dataclass(slots=True) class AggregatedOcrDocument: filename: str media_type: str source_key: str engine: str = "paddleocr_mobile" model: str = "PP-OCRv5_mobile" summary_fragments: list[str] = field(default_factory=list) text_fragments: list[str] = field(default_factory=list) text_layer_fragments: list[str] = field(default_factory=list) score_values: list[float] = field(default_factory=list) warnings: list[str] = field(default_factory=list) lines: list[OcrRecognizeLineRead] = field(default_factory=list) page_count: int = 0 preview_kind: str = "" preview_data_url: str = "" class OcrService: _cache_lock = threading.Lock() _result_cache: OrderedDict[str, OcrRecognizeDocumentRead] = OrderedDict() _worker_semaphore_lock = threading.Lock() _worker_semaphore: threading.Semaphore | None = None _worker_semaphore_limit = 0 def __init__(self, db: Session | None = None) -> None: self.settings = get_settings() self.document_intelligence_service = DocumentIntelligenceService(db) def recognize_files( self, files: list[tuple[str, bytes, str | None]], ) -> OcrRecognizeBatchRead: if not files: raise ValueError("至少需要上传一个文件。") temp_root = self.settings.resolved_ocr_temp_dir temp_root.mkdir(parents=True, exist_ok=True) documents: list[OcrRecognizeDocumentRead] = [] prepared_inputs: list[PreparedOcrInput] = [] cleanup_paths: list[Path] = [] worker_payload: dict = {} cache_keys_by_source: dict[str, str] = {} try: for filename, content, media_type in files: normalized_name = Path(str(filename or "").strip()).name or "upload.bin" suffix = Path(normalized_name).suffix.lower() resolved_media_type = str(media_type or "application/octet-stream") if not content: documents.append( OcrRecognizeDocumentRead( filename=normalized_name, media_type=resolved_media_type, warnings=["文件内容为空,未执行 OCR。"], ) ) continue if suffix not in SUPPORTED_SUFFIXES: documents.append( OcrRecognizeDocumentRead( filename=normalized_name, media_type=resolved_media_type, warnings=["当前仅支持图片和 PDF 文件进行 OCR。"], ) ) continue if len(content) > self.settings.ocr_max_file_size_mb * 1024 * 1024: documents.append( OcrRecognizeDocumentRead( filename=normalized_name, media_type=resolved_media_type, warnings=[ f"文件超过 {self.settings.ocr_max_file_size_mb} MB,未执行 OCR。" ], ) ) continue cache_key = self._build_cache_key(content) cached_document = self._read_cached_document( cache_key, filename=normalized_name, media_type=resolved_media_type, ) if cached_document is not None: documents.append(cached_document) continue temp_path = temp_root / f"{uuid4().hex}{suffix}" temp_path.write_bytes(content) cleanup_paths.append(temp_path) if suffix == ".pdf": try: text_layer = self._extract_pdf_text_layer(temp_path) pdf_inputs = self._prepare_pdf_inputs( pdf_path=temp_path, filename=normalized_name, media_type=resolved_media_type, cleanup_paths=cleanup_paths, text_layer=text_layer, ) if self._has_usable_pdf_text_layer(text_layer): document = self._build_text_layer_document( filename=normalized_name, media_type=resolved_media_type, text_layer=text_layer, pdf_inputs=pdf_inputs, ) documents.append(document) self._write_cached_document(cache_key, document) continue prepared_inputs.extend(pdf_inputs) for item in pdf_inputs: cache_keys_by_source.setdefault(item.source_key, cache_key) except RuntimeError as exc: documents.append( OcrRecognizeDocumentRead( filename=normalized_name, media_type=resolved_media_type, warnings=[str(exc)], ) ) continue source_key = uuid4().hex prepared_inputs.append( PreparedOcrInput( input_path=temp_path, source_key=source_key, filename=normalized_name, media_type=resolved_media_type, preview_kind="image" if resolved_media_type.startswith("image/") else "", preview_data_url=( self._build_preview_data_url(temp_path, media_type=resolved_media_type) if resolved_media_type.startswith("image/") else "" ), ) ) cache_keys_by_source[source_key] = cache_key if prepared_inputs: python_bin = self._resolve_python_bin() worker_path = self._resolve_worker_path() worker_payload = self._invoke_worker( python_bin=python_bin, worker_path=worker_path, input_paths=[item.input_path for item in prepared_inputs], ) recognized_documents = self._build_documents( worker_documents=worker_payload.get("documents", []), prepared_inputs=prepared_inputs, ) documents.extend(recognized_documents) self._write_cached_documents( recognized_documents, prepared_inputs=prepared_inputs, cache_keys_by_source=cache_keys_by_source, ) success_count = sum( 1 for item in documents if item.line_count > 0 or not item.warnings ) engine = ( str(worker_payload.get("engine", "paddleocr_mobile")) if prepared_inputs else "paddleocr_mobile" ) model = ( str(worker_payload.get("model", "PP-OCRv5_mobile")) if prepared_inputs else "PP-OCRv5_mobile" ) return OcrRecognizeBatchRead( engine=engine, model=model, total_file_count=len(files), success_count=success_count, documents=documents, ) finally: self._cleanup_temp_paths(cleanup_paths) def _resolve_python_bin(self) -> str: candidates = [] configured = str(self.settings.ocr_python_bin or "").strip() if configured: candidates.append(configured) candidates.append(str(SERVER_DIR / ".venv-ocr312" / "bin" / "python")) candidates.append("/usr/local/bin/python3.12") resolved = shutil.which("python3.12") if resolved: candidates.append(resolved) for candidate in candidates: if candidate and Path(candidate).exists(): return candidate raise RuntimeError( "未找到可用的 OCR Python 运行时。请先执行 scripts/bootstrap_paddleocr_mobile.sh " "或通过 OCR_PYTHON_BIN 指向已安装 PaddleOCR 的 Python 3.12。" ) @staticmethod def _resolve_worker_path() -> str: worker_path = SERVER_DIR / "scripts" / "paddle_ocr_worker.py" if not worker_path.exists(): raise RuntimeError(f"OCR worker 不存在:{worker_path}") return str(worker_path) def _build_cache_key(self, content: bytes) -> str: digest = hashlib.sha256(content).hexdigest() return "|".join( [ self.settings.ocr_language, self.settings.ocr_device, self.settings.ocr_text_detection_model, self.settings.ocr_text_recognition_model, digest, ] ) @classmethod def _read_cached_document( cls, cache_key: str, *, filename: str, media_type: str, ) -> OcrRecognizeDocumentRead | None: if not cache_key: return None with cls._cache_lock: cached = cls._result_cache.get(cache_key) if cached is None: return None cls._result_cache.move_to_end(cache_key) return cached.model_copy(update={"filename": filename, "media_type": media_type}) @classmethod def _write_cached_documents( cls, documents: list[OcrRecognizeDocumentRead], *, prepared_inputs: list[PreparedOcrInput], cache_keys_by_source: dict[str, str], ) -> None: if not documents or not cache_keys_by_source: return source_order: list[str] = [] seen_sources: set[str] = set() for item in prepared_inputs: if item.source_key in seen_sources: continue seen_sources.add(item.source_key) source_order.append(item.source_key) with cls._cache_lock: for source_key, document in zip(source_order, documents, strict=False): cache_key = cache_keys_by_source.get(source_key, "") if not cache_key: continue cls._result_cache[cache_key] = document.model_copy( update={ "receipt_id": "", "receipt_status": "", "receipt_preview_url": "", "receipt_source_url": "", } ) cls._result_cache.move_to_end(cache_key) while len(cls._result_cache) > OCR_RESULT_CACHE_LIMIT: cls._result_cache.popitem(last=False) @classmethod def _write_cached_document(cls, cache_key: str, document: OcrRecognizeDocumentRead) -> None: if not cache_key: return with cls._cache_lock: cls._result_cache[cache_key] = document.model_copy( update={ "receipt_id": "", "receipt_status": "", "receipt_preview_url": "", "receipt_source_url": "", } ) cls._result_cache.move_to_end(cache_key) while len(cls._result_cache) > OCR_RESULT_CACHE_LIMIT: cls._result_cache.popitem(last=False) @classmethod def _resolve_worker_semaphore(cls, limit: int) -> threading.Semaphore: normalized_limit = max(1, int(limit or 1)) with cls._worker_semaphore_lock: if cls._worker_semaphore is None or cls._worker_semaphore_limit != normalized_limit: cls._worker_semaphore = threading.Semaphore(normalized_limit) cls._worker_semaphore_limit = normalized_limit return cls._worker_semaphore def _invoke_worker( self, *, python_bin: str, worker_path: str, input_paths: list[Path], ) -> dict: command = [ python_bin, worker_path, "--lang", self.settings.ocr_language, "--text-detection-model", self.settings.ocr_text_detection_model, "--text-recognition-model", self.settings.ocr_text_recognition_model, ] configured_device = str(self.settings.ocr_device or "").strip() if configured_device: command.extend(["--device", configured_device]) for path in input_paths: command.extend(["--input", str(path)]) semaphore = self._resolve_worker_semaphore(self.settings.ocr_max_concurrent_workers) with semaphore: completed = subprocess.run( command, capture_output=True, text=True, timeout=self.settings.ocr_timeout_seconds, check=False, ) if completed.returncode != 0: detail = (completed.stderr or completed.stdout or "").strip() raise RuntimeError(f"OCR 执行失败:{detail or 'worker 返回非 0 状态码。'}") payload = self._parse_worker_stdout(completed.stdout) if payload is None: raise RuntimeError("OCR worker 未返回可解析的 JSON 结果。") return payload @staticmethod def _parse_worker_stdout(stdout: str) -> dict | None: for line in reversed(stdout.splitlines()): normalized = line.strip() if normalized.startswith(WORKER_JSON_PREFIX): return json.loads(normalized[len(WORKER_JSON_PREFIX) :]) return None def _prepare_pdf_inputs( self, *, pdf_path: Path, filename: str, media_type: str, cleanup_paths: list[Path], text_layer: str = "", ) -> list[PreparedOcrInput]: output_dir = pdf_path.with_suffix("") output_dir.mkdir(parents=True, exist_ok=True) cleanup_paths.append(output_dir) image_paths = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=output_dir) if not image_paths: raise RuntimeError("PDF 转图片后未生成可识别页面。") preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png") source_key = uuid4().hex descriptors: list[PreparedOcrInput] = [] for page_index, image_path in enumerate(image_paths): descriptors.append( PreparedOcrInput( input_path=image_path, source_key=source_key, filename=filename, media_type=media_type, page_index=page_index, preview_kind="image" if page_index == 0 else "", preview_data_url=preview_data_url if page_index == 0 else "", text_layer=text_layer if page_index == 0 else "", ) ) return descriptors def _extract_pdf_text_layer(self, pdf_path: Path) -> str: try: completed = subprocess.run( [ "pdftotext", "-layout", str(pdf_path), "-", ], capture_output=True, text=True, timeout=self.settings.ocr_timeout_seconds, check=False, ) except (OSError, subprocess.SubprocessError, UnicodeError): return "" if completed.returncode != 0: return "" return self._normalize_extracted_text(completed.stdout) def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]: prefix = output_dir / "page" completed = subprocess.run( [ "pdftoppm", "-png", "-r", "160", str(pdf_path), str(prefix), ], capture_output=True, text=True, timeout=self.settings.ocr_timeout_seconds, check=False, ) if completed.returncode != 0: detail = (completed.stderr or completed.stdout or "").strip() raise RuntimeError(f"PDF 转图片失败:{detail or 'pdftoppm 返回非 0 状态码。'}") return sorted(output_dir.glob("page-*.png"), key=self._extract_pdf_page_sort_key) @staticmethod def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]: suffix = path.stem.rsplit("-", 1)[-1] try: return int(suffix), path.name except ValueError: return 0, path.name @staticmethod def _build_preview_data_url(path: Path, *, media_type: str) -> str: encoded = base64.b64encode(path.read_bytes()).decode("ascii") return f"data:{media_type};base64,{encoded}" def _build_documents( self, *, worker_documents: list[dict], prepared_inputs: list[PreparedOcrInput], ) -> list[OcrRecognizeDocumentRead]: descriptor_by_path = {str(item.input_path): item for item in prepared_inputs} source_order: list[str] = [] seen_sources: set[str] = set() for item in prepared_inputs: if item.source_key in seen_sources: continue seen_sources.add(item.source_key) source_order.append(item.source_key) aggregated_by_source: dict[str, AggregatedOcrDocument] = {} for payload in worker_documents: if not isinstance(payload, dict): continue input_path = str(payload.get("input_path") or "") descriptor = descriptor_by_path.get(input_path) if descriptor is None: continue aggregated = aggregated_by_source.get(descriptor.source_key) if aggregated is None: aggregated = AggregatedOcrDocument( filename=descriptor.filename, media_type=descriptor.media_type, source_key=descriptor.source_key, engine=str(payload.get("engine", "paddleocr_mobile")), model=str(payload.get("model", "PP-OCRv5_mobile")), ) aggregated_by_source[descriptor.source_key] = aggregated aggregated.page_count = max( aggregated.page_count, (descriptor.page_index + 1) if descriptor.page_index is not None else int(payload.get("page_count", 1) or 1), ) if descriptor.preview_kind and not aggregated.preview_kind: aggregated.preview_kind = descriptor.preview_kind if descriptor.preview_data_url and not aggregated.preview_data_url: aggregated.preview_data_url = descriptor.preview_data_url if descriptor.text_layer and descriptor.text_layer not in aggregated.text_layer_fragments: aggregated.text_layer_fragments.append(descriptor.text_layer) page_summary = str(payload.get("summary", "") or "").strip() if page_summary: aggregated.summary_fragments.append(page_summary) page_text = self._resolve_worker_document_text(payload) if page_text: aggregated.text_fragments.append(page_text) lines = self._build_lines( payload.get("lines", []), page_index_override=descriptor.page_index, ) aggregated.lines.extend(lines) aggregated.score_values.extend(line.score for line in lines if line.score > 0) if not lines: avg_score = float(payload.get("avg_score", 0.0) or 0.0) if avg_score > 0: aggregated.score_values.append(avg_score) for warning in payload.get("warnings", []): normalized_warning = str(warning or "").strip() if normalized_warning and normalized_warning not in aggregated.warnings: aggregated.warnings.append(normalized_warning) documents: list[OcrRecognizeDocumentRead] = [] for source_key in source_order: descriptors = [item for item in prepared_inputs if item.source_key == source_key] if not descriptors: continue aggregated = aggregated_by_source.get(source_key) if aggregated is None: first_descriptor = descriptors[0] text_layer = self._collect_descriptor_text_layer(descriptors) if text_layer: fallback = AggregatedOcrDocument( filename=first_descriptor.filename, media_type=first_descriptor.media_type, source_key=first_descriptor.source_key, page_count=max(1, len(descriptors)), preview_kind=first_descriptor.preview_kind, preview_data_url=first_descriptor.preview_data_url, warnings=["OCR worker 未返回该文件的识别结果,已使用 PDF 文本层。"], ) fallback.text_layer_fragments.append(text_layer) documents.append(self._finalize_document(fallback)) continue documents.append( OcrRecognizeDocumentRead( filename=first_descriptor.filename, media_type=first_descriptor.media_type, page_count=max(1, len(descriptors)), preview_kind=first_descriptor.preview_kind, preview_data_url=first_descriptor.preview_data_url, warnings=["OCR worker 未返回该文件的识别结果。"], ) ) continue documents.append(self._finalize_document(aggregated)) return documents def _build_text_layer_document( self, *, filename: str, media_type: str, text_layer: str, pdf_inputs: list[PreparedOcrInput], ) -> OcrRecognizeDocumentRead: first_input = pdf_inputs[0] if pdf_inputs else None aggregated = AggregatedOcrDocument( filename=filename, media_type=media_type, source_key=first_input.source_key if first_input is not None else uuid4().hex, page_count=max(1, len(pdf_inputs)), preview_kind=str(first_input.preview_kind if first_input is not None else ""), preview_data_url=str(first_input.preview_data_url if first_input is not None else ""), ) aggregated.text_layer_fragments.append(text_layer) return self._finalize_document(aggregated) @classmethod def _has_usable_pdf_text_layer(cls, text_layer: str) -> bool: return cls._meaningful_char_count(text_layer) >= 8 @staticmethod def _collect_descriptor_text_layer(descriptors: list[PreparedOcrInput]) -> str: for descriptor in descriptors: if descriptor.text_layer: return descriptor.text_layer return "" @staticmethod def _resolve_worker_document_text(payload: dict) -> str: for key in ("text", "ocr_text", "raw_text", "full_text"): value = str(payload.get(key, "") or "").strip() if value: return value lines = payload.get("lines", []) if not isinstance(lines, list): return "" return "\n".join( str(item.get("text", "") or "").strip() for item in lines if isinstance(item, dict) and str(item.get("text", "") or "").strip() ).strip() @staticmethod def _build_lines( items: list[dict], *, page_index_override: int | None = None, ) -> list[OcrRecognizeLineRead]: lines: list[OcrRecognizeLineRead] = [] for item in items: if not isinstance(item, dict): continue page_index = page_index_override if page_index is None and item.get("page_index") is not None: page_index = int(item["page_index"]) lines.append( OcrRecognizeLineRead( text=str(item.get("text", "")), score=float(item.get("score", 0.0) or 0.0), box=[ [int(point[0]), int(point[1])] for point in item.get("box", []) if isinstance(point, list) and len(point) == 2 ], page_index=page_index, ) ) return lines @staticmethod def _truncate_summary(parts: list[str]) -> str: summary = ";".join([part for part in parts if part][:3]) if len(summary) > 180: return f"{summary[:177]}..." return summary def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead: ocr_text = "\n".join(fragment for fragment in aggregated.text_fragments if fragment).strip() text_layer = "\n".join(fragment for fragment in aggregated.text_layer_fragments if fragment).strip() full_text, used_text_layer = self._choose_document_text(ocr_text=ocr_text, text_layer=text_layer) summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments) if used_text_layer or self._placeholder_ratio(summary) >= 0.12: summary = self._summarize_text(full_text) preview_kind = aggregated.preview_kind preview_data_url = aggregated.preview_data_url if ( used_text_layer and aggregated.media_type == "application/pdf" and self._placeholder_ratio(ocr_text) >= 0.12 ): preview_kind = "" preview_data_url = "" insight = self.document_intelligence_service.build_document_insight( filename=aggregated.filename, summary=summary, text=full_text, preview_data_url=preview_data_url, ) warnings = list(aggregated.warnings) for warning in insight.warnings: normalized_warning = str(warning or "").strip() if normalized_warning and normalized_warning not in warnings: warnings.append(normalized_warning) return OcrRecognizeDocumentRead( filename=aggregated.filename, media_type=aggregated.media_type, engine=aggregated.engine, model=aggregated.model, text=full_text, summary=summary, avg_score=( sum(aggregated.score_values) / len(aggregated.score_values) if aggregated.score_values else 0.0 ), line_count=len(aggregated.lines), page_count=max(1, aggregated.page_count), document_type=insight.document_type, document_type_label=insight.document_type_label, scene_code=insight.scene_code, scene_label=insight.scene_label, classification_source=insight.classification_source, classification_confidence=insight.classification_confidence, classification_evidence=list(insight.evidence), document_fields=[ OcrRecognizeFieldRead( key=field.key, label=field.label, value=field.value, ) for field in insight.fields ], preview_kind=preview_kind, preview_data_url=preview_data_url, warnings=warnings, lines=sorted( aggregated.lines, key=lambda item: item.page_index if item.page_index is not None else -1, ), ) @classmethod def _choose_document_text(cls, *, ocr_text: str, text_layer: str) -> tuple[str, bool]: normalized_ocr_text = cls._normalize_extracted_text(ocr_text) normalized_text_layer = cls._normalize_extracted_text(text_layer) if not normalized_text_layer: return normalized_ocr_text, False if not normalized_ocr_text: return normalized_text_layer, True if cls._placeholder_ratio(normalized_ocr_text) >= 0.12 and cls._meaningful_char_count(normalized_text_layer) >= 8: return normalized_text_layer, True if cls._meaningful_char_count(normalized_text_layer) > cls._meaningful_char_count(normalized_ocr_text) * 1.3: return normalized_text_layer, True return normalized_ocr_text, False @staticmethod def _normalize_extracted_text(value: str) -> str: lines = [re.sub(r"[ \t]+", " ", line).strip() for line in str(value or "").replace("\r", "\n").split("\n")] return "\n".join(line for line in lines if line).strip() @staticmethod def _summarize_text(value: str) -> str: lines = [line.strip() for line in str(value or "").splitlines() if line.strip()] summary = ";".join(lines[:3]) if len(summary) > 180: return f"{summary[:177]}..." return summary @staticmethod def _meaningful_char_count(value: str) -> int: return len(re.findall(r"[0-9A-Za-z\u4e00-\u9fff]", str(value or ""))) @staticmethod def _placeholder_ratio(value: str) -> float: chars = [char for char in str(value or "") if not char.isspace()] if not chars: return 0.0 placeholder_count = sum(1 for char in chars if char in {"□", "�"}) return placeholder_count / len(chars) @staticmethod def _cleanup_temp_paths(paths: list[Path]) -> None: for path in reversed(paths): if path.is_dir(): shutil.rmtree(path, ignore_errors=True) continue path.unlink(missing_ok=True)