from __future__ import annotations import base64 import json import shutil import subprocess from dataclasses import dataclass, field from pathlib import Path from uuid import uuid4 from sqlalchemy.orm import Session from app.core.config import SERVER_DIR, get_settings from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead from app.services.document_intelligence import DocumentIntelligenceService WORKER_JSON_PREFIX = "__OCR_JSON__=" SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"} @dataclass(slots=True) class PreparedOcrInput: input_path: Path source_key: str filename: str media_type: str page_index: int | None = None preview_kind: str = "" preview_data_url: str = "" @dataclass(slots=True) class AggregatedOcrDocument: filename: str media_type: str source_key: str engine: str = "paddleocr_mobile" model: str = "PP-OCRv5_mobile" summary_fragments: list[str] = field(default_factory=list) text_fragments: list[str] = field(default_factory=list) score_values: list[float] = field(default_factory=list) warnings: list[str] = field(default_factory=list) lines: list[OcrRecognizeLineRead] = field(default_factory=list) page_count: int = 0 preview_kind: str = "" preview_data_url: str = "" class OcrService: def __init__(self, db: Session | None = None) -> None: self.settings = get_settings() self.document_intelligence_service = DocumentIntelligenceService(db) def recognize_files( self, files: list[tuple[str, bytes, str | None]], ) -> OcrRecognizeBatchRead: if not files: raise ValueError("至少需要上传一个文件。") temp_root = self.settings.resolved_ocr_temp_dir temp_root.mkdir(parents=True, exist_ok=True) documents: list[OcrRecognizeDocumentRead] = [] prepared_inputs: list[PreparedOcrInput] = [] cleanup_paths: list[Path] = [] python_bin = self._resolve_python_bin() worker_path = self._resolve_worker_path() worker_payload: dict = {} try: for filename, content, media_type in files: normalized_name = Path(str(filename or "").strip()).name or "upload.bin" suffix = Path(normalized_name).suffix.lower() resolved_media_type = str(media_type or "application/octet-stream") if not content: documents.append( OcrRecognizeDocumentRead( filename=normalized_name, media_type=resolved_media_type, warnings=["文件内容为空,未执行 OCR。"], ) ) continue if suffix not in SUPPORTED_SUFFIXES: documents.append( OcrRecognizeDocumentRead( filename=normalized_name, media_type=resolved_media_type, warnings=["当前仅支持图片和 PDF 文件进行 OCR。"], ) ) continue if len(content) > self.settings.ocr_max_file_size_mb * 1024 * 1024: documents.append( OcrRecognizeDocumentRead( filename=normalized_name, media_type=resolved_media_type, warnings=[ f"文件超过 {self.settings.ocr_max_file_size_mb} MB,未执行 OCR。" ], ) ) continue temp_path = temp_root / f"{uuid4().hex}{suffix}" temp_path.write_bytes(content) cleanup_paths.append(temp_path) if suffix == ".pdf": try: prepared_inputs.extend( self._prepare_pdf_inputs( pdf_path=temp_path, filename=normalized_name, media_type=resolved_media_type, cleanup_paths=cleanup_paths, ) ) except RuntimeError as exc: documents.append( OcrRecognizeDocumentRead( filename=normalized_name, media_type=resolved_media_type, warnings=[str(exc)], ) ) continue prepared_inputs.append( PreparedOcrInput( input_path=temp_path, source_key=uuid4().hex, filename=normalized_name, media_type=resolved_media_type, preview_kind="image" if resolved_media_type.startswith("image/") else "", preview_data_url=( self._build_preview_data_url(temp_path, media_type=resolved_media_type) if resolved_media_type.startswith("image/") else "" ), ) ) if prepared_inputs: worker_payload = self._invoke_worker( python_bin=python_bin, worker_path=worker_path, input_paths=[item.input_path for item in prepared_inputs], ) documents.extend( self._build_documents( worker_documents=worker_payload.get("documents", []), prepared_inputs=prepared_inputs, ) ) success_count = sum( 1 for item in documents if item.line_count > 0 or not item.warnings ) engine = ( str(worker_payload.get("engine", "paddleocr_mobile")) if prepared_inputs else "paddleocr_mobile" ) model = ( str(worker_payload.get("model", "PP-OCRv5_mobile")) if prepared_inputs else "PP-OCRv5_mobile" ) return OcrRecognizeBatchRead( engine=engine, model=model, total_file_count=len(files), success_count=success_count, documents=documents, ) finally: self._cleanup_temp_paths(cleanup_paths) def _resolve_python_bin(self) -> str: candidates = [] configured = str(self.settings.ocr_python_bin or "").strip() if configured: candidates.append(configured) candidates.append(str(SERVER_DIR / ".venv-ocr312" / "bin" / "python")) candidates.append("/usr/local/bin/python3.12") resolved = shutil.which("python3.12") if resolved: candidates.append(resolved) for candidate in candidates: if candidate and Path(candidate).exists(): return candidate raise RuntimeError( "未找到可用的 OCR Python 运行时。请先执行 scripts/bootstrap_paddleocr_mobile.sh " "或通过 OCR_PYTHON_BIN 指向已安装 PaddleOCR 的 Python 3.12。" ) @staticmethod def _resolve_worker_path() -> str: worker_path = SERVER_DIR / "scripts" / "paddle_ocr_worker.py" if not worker_path.exists(): raise RuntimeError(f"OCR worker 不存在:{worker_path}") return str(worker_path) def _invoke_worker( self, *, python_bin: str, worker_path: str, input_paths: list[Path], ) -> dict: command = [ python_bin, worker_path, "--lang", self.settings.ocr_language, "--text-detection-model", self.settings.ocr_text_detection_model, "--text-recognition-model", self.settings.ocr_text_recognition_model, ] for path in input_paths: command.extend(["--input", str(path)]) completed = subprocess.run( command, capture_output=True, text=True, timeout=self.settings.ocr_timeout_seconds, check=False, ) if completed.returncode != 0: detail = (completed.stderr or completed.stdout or "").strip() raise RuntimeError(f"OCR 执行失败:{detail or 'worker 返回非 0 状态码。'}") payload = self._parse_worker_stdout(completed.stdout) if payload is None: raise RuntimeError("OCR worker 未返回可解析的 JSON 结果。") return payload @staticmethod def _parse_worker_stdout(stdout: str) -> dict | None: for line in reversed(stdout.splitlines()): normalized = line.strip() if normalized.startswith(WORKER_JSON_PREFIX): return json.loads(normalized[len(WORKER_JSON_PREFIX) :]) return None def _prepare_pdf_inputs( self, *, pdf_path: Path, filename: str, media_type: str, cleanup_paths: list[Path], ) -> list[PreparedOcrInput]: output_dir = pdf_path.with_suffix("") output_dir.mkdir(parents=True, exist_ok=True) cleanup_paths.append(output_dir) image_paths = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=output_dir) if not image_paths: raise RuntimeError("PDF 转图片后未生成可识别页面。") preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png") source_key = uuid4().hex descriptors: list[PreparedOcrInput] = [] for page_index, image_path in enumerate(image_paths): descriptors.append( PreparedOcrInput( input_path=image_path, source_key=source_key, filename=filename, media_type=media_type, page_index=page_index, preview_kind="image" if page_index == 0 else "", preview_data_url=preview_data_url if page_index == 0 else "", ) ) return descriptors def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]: prefix = output_dir / "page" completed = subprocess.run( [ "pdftoppm", "-png", "-r", "160", str(pdf_path), str(prefix), ], capture_output=True, text=True, timeout=self.settings.ocr_timeout_seconds, check=False, ) if completed.returncode != 0: detail = (completed.stderr or completed.stdout or "").strip() raise RuntimeError(f"PDF 转图片失败:{detail or 'pdftoppm 返回非 0 状态码。'}") return sorted(output_dir.glob("page-*.png"), key=self._extract_pdf_page_sort_key) @staticmethod def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]: suffix = path.stem.rsplit("-", 1)[-1] try: return int(suffix), path.name except ValueError: return 0, path.name @staticmethod def _build_preview_data_url(path: Path, *, media_type: str) -> str: encoded = base64.b64encode(path.read_bytes()).decode("ascii") return f"data:{media_type};base64,{encoded}" def _build_documents( self, *, worker_documents: list[dict], prepared_inputs: list[PreparedOcrInput], ) -> list[OcrRecognizeDocumentRead]: descriptor_by_path = {str(item.input_path): item for item in prepared_inputs} source_order: list[str] = [] seen_sources: set[str] = set() for item in prepared_inputs: if item.source_key in seen_sources: continue seen_sources.add(item.source_key) source_order.append(item.source_key) aggregated_by_source: dict[str, AggregatedOcrDocument] = {} for payload in worker_documents: if not isinstance(payload, dict): continue input_path = str(payload.get("input_path") or "") descriptor = descriptor_by_path.get(input_path) if descriptor is None: continue aggregated = aggregated_by_source.get(descriptor.source_key) if aggregated is None: aggregated = AggregatedOcrDocument( filename=descriptor.filename, media_type=descriptor.media_type, source_key=descriptor.source_key, engine=str(payload.get("engine", "paddleocr_mobile")), model=str(payload.get("model", "PP-OCRv5_mobile")), ) aggregated_by_source[descriptor.source_key] = aggregated aggregated.page_count = max( aggregated.page_count, (descriptor.page_index + 1) if descriptor.page_index is not None else int(payload.get("page_count", 1) or 1), ) if descriptor.preview_kind and not aggregated.preview_kind: aggregated.preview_kind = descriptor.preview_kind if descriptor.preview_data_url and not aggregated.preview_data_url: aggregated.preview_data_url = descriptor.preview_data_url page_summary = str(payload.get("summary", "") or "").strip() if page_summary: aggregated.summary_fragments.append(page_summary) page_text = str(payload.get("text", "") or "").strip() if page_text: aggregated.text_fragments.append(page_text) lines = self._build_lines( payload.get("lines", []), page_index_override=descriptor.page_index, ) aggregated.lines.extend(lines) aggregated.score_values.extend(line.score for line in lines if line.score > 0) if not lines: avg_score = float(payload.get("avg_score", 0.0) or 0.0) if avg_score > 0: aggregated.score_values.append(avg_score) for warning in payload.get("warnings", []): normalized_warning = str(warning or "").strip() if normalized_warning and normalized_warning not in aggregated.warnings: aggregated.warnings.append(normalized_warning) documents: list[OcrRecognizeDocumentRead] = [] for source_key in source_order: descriptors = [item for item in prepared_inputs if item.source_key == source_key] if not descriptors: continue aggregated = aggregated_by_source.get(source_key) if aggregated is None: first_descriptor = descriptors[0] documents.append( OcrRecognizeDocumentRead( filename=first_descriptor.filename, media_type=first_descriptor.media_type, page_count=max(1, len(descriptors)), preview_kind=first_descriptor.preview_kind, preview_data_url=first_descriptor.preview_data_url, warnings=["OCR worker 未返回该文件的识别结果。"], ) ) continue documents.append(self._finalize_document(aggregated)) return documents @staticmethod def _build_lines( items: list[dict], *, page_index_override: int | None = None, ) -> list[OcrRecognizeLineRead]: lines: list[OcrRecognizeLineRead] = [] for item in items: if not isinstance(item, dict): continue page_index = page_index_override if page_index is None and item.get("page_index") is not None: page_index = int(item["page_index"]) lines.append( OcrRecognizeLineRead( text=str(item.get("text", "")), score=float(item.get("score", 0.0) or 0.0), box=[ [int(point[0]), int(point[1])] for point in item.get("box", []) if isinstance(point, list) and len(point) == 2 ], page_index=page_index, ) ) return lines @staticmethod def _truncate_summary(parts: list[str]) -> str: summary = ";".join([part for part in parts if part][:3]) if len(summary) > 180: return f"{summary[:177]}..." return summary def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead: full_text = "\n".join(fragment for fragment in aggregated.text_fragments if fragment).strip() summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments) insight = self.document_intelligence_service.build_document_insight( filename=aggregated.filename, summary=summary, text=full_text, preview_data_url=aggregated.preview_data_url, ) warnings = list(aggregated.warnings) for warning in insight.warnings: normalized_warning = str(warning or "").strip() if normalized_warning and normalized_warning not in warnings: warnings.append(normalized_warning) return OcrRecognizeDocumentRead( filename=aggregated.filename, media_type=aggregated.media_type, engine=aggregated.engine, model=aggregated.model, text=full_text, summary=summary, avg_score=( sum(aggregated.score_values) / len(aggregated.score_values) if aggregated.score_values else 0.0 ), line_count=len(aggregated.lines), page_count=max(1, aggregated.page_count), document_type=insight.document_type, document_type_label=insight.document_type_label, scene_code=insight.scene_code, scene_label=insight.scene_label, classification_source=insight.classification_source, classification_confidence=insight.classification_confidence, classification_evidence=list(insight.evidence), document_fields=[ OcrRecognizeFieldRead( key=field.key, label=field.label, value=field.value, ) for field in insight.fields ], preview_kind=aggregated.preview_kind, preview_data_url=aggregated.preview_data_url, warnings=warnings, lines=sorted( aggregated.lines, key=lambda item: item.page_index if item.page_index is not None else -1, ), ) @staticmethod def _cleanup_temp_paths(paths: list[Path]) -> None: for path in reversed(paths): if path.is_dir(): shutil.rmtree(path, ignore_errors=True) continue path.unlink(missing_ok=True)