222 lines
8.2 KiB
Python
222 lines
8.2 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import shutil
|
|||
|
|
import subprocess
|
|||
|
|
from pathlib import Path
|
|||
|
|
from uuid import uuid4
|
|||
|
|
|
|||
|
|
from app.core.config import SERVER_DIR, get_settings
|
|||
|
|
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
|
|||
|
|
|
|||
|
|
# Prefix that marks the worker's JSON result line on stdout; lines without it
# are skipped when the output is scanned for the result.
WORKER_JSON_PREFIX = "__OCR_JSON__="
# File extensions (compared lower-cased) that are handed to the OCR worker.
SUPPORTED_SUFFIXES = {".bmp", ".jpeg", ".jpg", ".pdf", ".png", ".webp"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Defaults used when the worker omits (or nulls) engine/model metadata, or
# when no file in the batch reached the worker at all.
_DEFAULT_ENGINE = "paddleocr_mobile"
_DEFAULT_MODEL = "PP-OCRv5_mobile"
# Placeholders for uploads without a usable filename / declared media type.
_FALLBACK_FILENAME = "upload.bin"
_FALLBACK_MEDIA_TYPE = "application/octet-stream"


class OcrService:
    """Run PaddleOCR on uploaded files through an out-of-process worker script.

    OCR runs in a separate Python interpreter (see ``_resolve_python_bin``).
    This service stages uploads as temporary files, launches the worker once
    per batch, and converts the worker's JSON output into response schemas.
    """

    def __init__(self) -> None:
        self.settings = get_settings()

    def recognize_files(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        """Run OCR on a batch of uploaded files.

        Args:
            files: ``(filename, content, media_type)`` tuples. Files that are
                empty, have an unsupported extension, or exceed
                ``settings.ocr_max_file_size_mb`` are not sent to the worker.
                They are reported as documents carrying a warning instead.

        Returns:
            The batch result. Skipped files come first (in input order),
            followed by the documents in the order the worker reports them.

        Raises:
            ValueError: if ``files`` is empty.
            RuntimeError: if the OCR runtime or worker script is missing, or
                the worker fails, times out, or returns no parsable result.
        """
        if not files:
            raise ValueError("至少需要上传一个文件。")

        temp_root = self.settings.resolved_ocr_temp_dir
        temp_root.mkdir(parents=True, exist_ok=True)

        documents: list[OcrRecognizeDocumentRead] = []
        input_paths: list[Path] = []
        meta_by_path: dict[str, tuple[str, str]] = {}
        # Resolve the runtime before writing anything, so a broken setup fails
        # without touching the temp directory.
        python_bin = self._resolve_python_bin()
        worker_path = self._resolve_worker_path()
        # Stays empty when every file is rejected, so engine/model below fall
        # back to the defaults without a separate code path.
        worker_payload: dict = {}

        try:
            for filename, content, media_type in files:
                normalized_name = (
                    Path(str(filename or "").strip()).name or _FALLBACK_FILENAME
                )
                suffix = Path(normalized_name).suffix.lower()
                resolved_media_type = str(media_type or _FALLBACK_MEDIA_TYPE)

                warning = self._rejection_warning(content, suffix)
                if warning is not None:
                    documents.append(
                        OcrRecognizeDocumentRead(
                            filename=normalized_name,
                            media_type=resolved_media_type,
                            warnings=[warning],
                        )
                    )
                    continue

                # A random name keeps the client-supplied filename out of the
                # filesystem; only the lower-cased extension is reused.
                temp_path = temp_root / f"{uuid4().hex}{suffix}"
                # Register for cleanup *before* writing so a partially written
                # file (e.g. disk full mid-write) is still removed in `finally`.
                input_paths.append(temp_path)
                temp_path.write_bytes(content)
                meta_by_path[str(temp_path)] = (normalized_name, resolved_media_type)

            if input_paths:
                worker_payload = self._invoke_worker(
                    python_bin=python_bin,
                    worker_path=worker_path,
                    input_paths=input_paths,
                )
                for item in worker_payload.get("documents") or []:
                    if isinstance(item, dict):
                        documents.append(self._build_document(item, meta_by_path))

            # A document counts as a success if it produced text or finished
            # without any warning.
            success_count = sum(
                1 for item in documents if item.line_count > 0 or not item.warnings
            )
            return OcrRecognizeBatchRead(
                engine=str(worker_payload.get("engine") or _DEFAULT_ENGINE),
                model=str(worker_payload.get("model") or _DEFAULT_MODEL),
                total_file_count=len(files),
                success_count=success_count,
                documents=documents,
            )
        finally:
            for path in input_paths:
                path.unlink(missing_ok=True)

    def _rejection_warning(self, content: bytes, suffix: str) -> str | None:
        """Return why a file must be skipped, or ``None`` if it can be OCR'd.

        Checks run in order: empty content, unsupported extension, size limit.
        """
        if not content:
            return "文件内容为空,未执行 OCR。"
        if suffix not in SUPPORTED_SUFFIXES:
            return "当前仅支持图片和 PDF 文件进行 OCR。"
        max_mb = self.settings.ocr_max_file_size_mb
        if len(content) > max_mb * 1024 * 1024:
            return f"文件超过 {max_mb} MB,未执行 OCR。"
        return None

    def _resolve_python_bin(self) -> str:
        """Return the interpreter used to run the OCR worker.

        Candidates, in priority order: ``settings.ocr_python_bin``, the
        project-local ``.venv-ocr312`` interpreter, ``/usr/local/bin/python3.12``
        and ``python3.12`` looked up on ``PATH``. A candidate may be a path or a
        bare command name. The first one that resolves to an executable file
        wins.

        Raises:
            RuntimeError: if no candidate is usable.
        """
        candidates = [
            str(self.settings.ocr_python_bin or "").strip(),
            str(SERVER_DIR / ".venv-ocr312" / "bin" / "python"),
            "/usr/local/bin/python3.12",
            "python3.12",
        ]
        for candidate in candidates:
            if not candidate:
                continue
            # `shutil.which` checks a path directly and searches PATH for a
            # bare name. Unlike `Path.exists`, it rejects directories and
            # non-executable files.
            resolved = shutil.which(candidate)
            if resolved:
                return resolved

        raise RuntimeError(
            "未找到可用的 OCR Python 运行时。请先执行 scripts/bootstrap_paddleocr_mobile.sh "
            "或通过 OCR_PYTHON_BIN 指向已安装 PaddleOCR 的 Python 3.12。"
        )

    @staticmethod
    def _resolve_worker_path() -> str:
        """Return the path of the bundled worker script, or raise ``RuntimeError``."""
        worker_path = SERVER_DIR / "scripts" / "paddle_ocr_worker.py"
        if not worker_path.exists():
            raise RuntimeError(f"OCR worker 不存在:{worker_path}")
        return str(worker_path)

    def _invoke_worker(
        self,
        *,
        python_bin: str,
        worker_path: str,
        input_paths: list[Path],
    ) -> dict:
        """Run the worker on ``input_paths`` and return its parsed JSON payload.

        Raises:
            RuntimeError: if the worker cannot be started, exceeds
                ``settings.ocr_timeout_seconds``, exits non-zero, or prints no
                parsable JSON result line.
        """
        command = [
            python_bin,
            worker_path,
            "--lang",
            self.settings.ocr_language,
            "--text-detection-model",
            self.settings.ocr_text_detection_model,
            "--text-recognition-model",
            self.settings.ocr_text_recognition_model,
        ]
        for path in input_paths:
            command.extend(["--input", str(path)])

        timeout = self.settings.ocr_timeout_seconds
        try:
            completed = subprocess.run(
                command,
                capture_output=True,
                text=True,
                timeout=timeout,
                check=False,
            )
        except subprocess.TimeoutExpired as err:
            # `run` has already killed the child. Surface this as the same
            # exception type as every other worker failure.
            raise RuntimeError(f"OCR 执行超时:超过 {timeout} 秒未完成。") from err
        except OSError as err:
            # e.g. interpreter removed or not executable since it was resolved.
            raise RuntimeError(f"OCR worker 启动失败:{err}") from err

        if completed.returncode != 0:
            detail = (completed.stderr or completed.stdout or "").strip()
            raise RuntimeError(f"OCR 执行失败:{detail or 'worker 返回非 0 状态码。'}")

        payload = self._parse_worker_stdout(completed.stdout)
        if payload is None:
            raise RuntimeError("OCR worker 未返回可解析的 JSON 结果。")
        return payload

    @staticmethod
    def _parse_worker_stdout(stdout: str) -> dict | None:
        """Extract the JSON object from the last ``WORKER_JSON_PREFIX`` line.

        Returns ``None`` when no such line exists, when its JSON is malformed,
        or when it does not decode to an object. The caller then reports a
        ``RuntimeError`` instead of leaking a ``ValueError``.
        """
        for line in reversed((stdout or "").splitlines()):
            normalized = line.strip()
            if not normalized.startswith(WORKER_JSON_PREFIX):
                continue
            try:
                payload = json.loads(normalized[len(WORKER_JSON_PREFIX) :])
            except json.JSONDecodeError:
                return None
            return payload if isinstance(payload, dict) else None
        return None

    @staticmethod
    def _build_document(
        payload: dict,
        meta_by_path: dict[str, tuple[str, str]],
    ) -> OcrRecognizeDocumentRead:
        """Convert one worker document entry into ``OcrRecognizeDocumentRead``.

        The original upload name and media type are recovered through
        ``meta_by_path`` (keyed by temp-file path). Unknown paths fall back to
        the path's basename and a generic media type. Missing or ``null`` fields
        in ``payload`` fall back to neutral defaults instead of raising.
        """
        input_path = str(payload.get("input_path") or "")
        filename, media_type = meta_by_path.get(
            input_path,
            (Path(input_path).name or _FALLBACK_FILENAME, _FALLBACK_MEDIA_TYPE),
        )
        lines = [
            OcrRecognizeLineRead(
                text=str(item.get("text") or ""),
                score=float(item.get("score") or 0.0),
                # Keep only well-formed (x, y) points and truncate coordinates
                # to ints.
                box=[
                    [int(point[0]), int(point[1])]
                    for point in item.get("box") or []
                    if isinstance(point, (list, tuple)) and len(point) == 2
                ],
                page_index=int(item["page_index"]) if item.get("page_index") is not None else None,
            )
            for item in payload.get("lines") or []
            if isinstance(item, dict)
        ]
        # When the worker omits the count, derive it from the parsed lines.
        raw_line_count = payload.get("line_count")
        line_count = len(lines) if raw_line_count is None else int(raw_line_count or 0)
        return OcrRecognizeDocumentRead(
            filename=filename,
            media_type=media_type,
            engine=str(payload.get("engine") or _DEFAULT_ENGINE),
            model=str(payload.get("model") or _DEFAULT_MODEL),
            text=str(payload.get("text") or ""),
            summary=str(payload.get("summary") or ""),
            avg_score=float(payload.get("avg_score") or 0.0),
            line_count=line_count,
            page_count=int(payload.get("page_count") or 1),
            warnings=[str(item) for item in payload.get("warnings") or []],
            lines=lines,
        )
|