Files
X-Financial/server/src/app/services/ocr.py

512 lines
20 KiB
Python
Raw Normal View History

from __future__ import annotations
import base64
import json
import shutil
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from uuid import uuid4
from sqlalchemy.orm import Session
from app.core.config import SERVER_DIR, get_settings
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
from app.services.document_intelligence import DocumentIntelligenceService
# Marker the OCR worker writes in front of the stdout line that carries its
# JSON result; only text after this prefix is parsed, other stdout lines
# (e.g. library logging) are ignored.
WORKER_JSON_PREFIX = "__OCR_JSON__="
# Lower-cased file suffixes accepted for OCR. Uploads with any other suffix
# are reported back with a warning and never sent to the worker.
SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"}
@dataclass(slots=True)
class PreparedOcrInput:
    """One image file handed to the OCR worker.

    An image upload yields a single instance pointing at the uploaded file;
    a PDF upload yields one instance per rasterised page, all sharing the
    same ``source_key`` so their results can be merged back together.
    """

    # Image file passed to the worker via ``--input``.
    input_path: Path
    # Random hex id grouping all pages that came from the same upload.
    source_key: str
    # Base name of the uploaded file (display only).
    filename: str
    # Media type reported for the upload (falls back to application/octet-stream).
    media_type: str
    # 0-based page number for PDF pages; ``None`` for standalone images.
    page_index: int | None = None
    # "image" when a preview is attached to this input, otherwise "".
    preview_kind: str = ""
    # ``data:`` URL with base64-encoded preview bytes, or "" when absent.
    preview_data_url: str = ""
@dataclass(slots=True)
class AggregatedOcrDocument:
    """Mutable accumulator for all worker results of one uploaded file.

    ``OcrService._build_documents`` fills one instance per ``source_key``
    (i.e. per upload, merging PDF pages) and ``OcrService._finalize_document``
    converts it into the API response model.
    """

    filename: str
    media_type: str
    source_key: str
    # Engine/model reported by the worker for the first page seen; these
    # defaults mirror the fallbacks used when the worker omits them.
    engine: str = "paddleocr_mobile"
    model: str = "PP-OCRv5_mobile"
    # Per-page "summary" strings from the worker, in arrival order.
    summary_fragments: list[str] = field(default_factory=list)
    # Per-page recognised text, in arrival order.
    text_fragments: list[str] = field(default_factory=list)
    # Positive line scores, or the page-level avg_score for pages without lines.
    score_values: list[float] = field(default_factory=list)
    # Worker warnings, de-duplicated while preserving first occurrence.
    warnings: list[str] = field(default_factory=list)
    lines: list[OcrRecognizeLineRead] = field(default_factory=list)
    # Highest page count observed across this upload's worker results.
    page_count: int = 0
    preview_kind: str = ""
    preview_data_url: str = ""
class OcrService:
    """Run PaddleOCR over uploaded images and PDFs via an external worker.

    OCR itself runs in a separate Python interpreter (located by
    ``_resolve_python_bin``) that executes ``scripts/paddle_ocr_worker.py`` as
    a subprocess, once per batch. PDFs are first rasterised page by page with
    ``pdftoppm``; the worker's per-image results are then merged back into one
    document per uploaded file and classified by ``DocumentIntelligenceService``.
    """

    def __init__(self, db: Session | None = None) -> None:
        """Load settings and build the document-intelligence helper.

        Args:
            db: Optional SQLAlchemy session; only forwarded to
                ``DocumentIntelligenceService``.
        """
        self.settings = get_settings()
        self.document_intelligence_service = DocumentIntelligenceService(db)

    def recognize_files(
        self,
        files: list[tuple[str, bytes, str | None]],
    ) -> OcrRecognizeBatchRead:
        """OCR a batch of uploaded files and return one result per file.

        Args:
            files: ``(filename, content, media_type)`` tuples; ``media_type``
                may be ``None``.

        Returns:
            An ``OcrRecognizeBatchRead``. Files that cannot be processed
            (empty, unsupported suffix, too large, PDF conversion failure) are
            reported as documents carrying a warning instead of failing the
            whole batch. Such rejected files are listed before the OCR'd ones,
            so ``documents`` is not necessarily in upload order.

        Raises:
            ValueError: if ``files`` is empty.
            RuntimeError: if the OCR runtime or worker script cannot be found,
                or the worker run fails / times out.
        """
        if not files:
            raise ValueError("至少需要上传一个文件。")
        temp_root = self.settings.resolved_ocr_temp_dir
        temp_root.mkdir(parents=True, exist_ok=True)
        max_file_bytes = self.settings.ocr_max_file_size_mb * 1024 * 1024
        documents: list[OcrRecognizeDocumentRead] = []
        prepared_inputs: list[PreparedOcrInput] = []
        cleanup_paths: list[Path] = []
        worker_payload: dict = {}
        try:
            for filename, content, media_type in files:
                # Only the base name is kept (for display and suffix detection);
                # the temp file itself is named by a random UUID.
                normalized_name = Path(str(filename or "").strip()).name or "upload.bin"
                suffix = Path(normalized_name).suffix.lower()
                resolved_media_type = str(media_type or "application/octet-stream")
                if not content:
                    documents.append(
                        self._rejected_document(
                            normalized_name,
                            resolved_media_type,
                            "文件内容为空,未执行 OCR。",
                        )
                    )
                    continue
                if suffix not in SUPPORTED_SUFFIXES:
                    documents.append(
                        self._rejected_document(
                            normalized_name,
                            resolved_media_type,
                            "当前仅支持图片和 PDF 文件进行 OCR。",
                        )
                    )
                    continue
                if len(content) > max_file_bytes:
                    documents.append(
                        self._rejected_document(
                            normalized_name,
                            resolved_media_type,
                            f"文件超过 {self.settings.ocr_max_file_size_mb} MB未执行 OCR。",
                        )
                    )
                    continue
                temp_path = temp_root / f"{uuid4().hex}{suffix}"
                # Register before writing so a partially written file is
                # still removed if write_bytes fails midway.
                cleanup_paths.append(temp_path)
                temp_path.write_bytes(content)
                if suffix == ".pdf":
                    try:
                        prepared_inputs.extend(
                            self._prepare_pdf_inputs(
                                pdf_path=temp_path,
                                filename=normalized_name,
                                media_type=resolved_media_type,
                                cleanup_paths=cleanup_paths,
                            )
                        )
                    except RuntimeError as exc:
                        documents.append(
                            self._rejected_document(
                                normalized_name, resolved_media_type, str(exc)
                            )
                        )
                    # The PDF itself is never sent to the worker, only its pages.
                    continue
                is_image = resolved_media_type.startswith("image/")
                prepared_inputs.append(
                    PreparedOcrInput(
                        input_path=temp_path,
                        source_key=uuid4().hex,
                        filename=normalized_name,
                        media_type=resolved_media_type,
                        preview_kind="image" if is_image else "",
                        preview_data_url=(
                            self._build_preview_data_url(
                                temp_path, media_type=resolved_media_type
                            )
                            if is_image
                            else ""
                        ),
                    )
                )
            if prepared_inputs:
                # Resolve the OCR runtime only when there is something to OCR:
                # a batch in which every file was rejected above must still
                # return its per-file warnings.
                worker_payload = self._invoke_worker(
                    python_bin=self._resolve_python_bin(),
                    worker_path=self._resolve_worker_path(),
                    input_paths=[item.input_path for item in prepared_inputs],
                )
                documents.extend(
                    self._build_documents(
                        worker_documents=worker_payload.get("documents") or [],
                        prepared_inputs=prepared_inputs,
                    )
                )
            success_count = sum(
                1 for item in documents if item.line_count > 0 or not item.warnings
            )
            # worker_payload stays {} when nothing was OCR'd, so the defaults
            # below also cover that case.
            engine = str(worker_payload.get("engine", "paddleocr_mobile"))
            model = str(worker_payload.get("model", "PP-OCRv5_mobile"))
            return OcrRecognizeBatchRead(
                engine=engine,
                model=model,
                total_file_count=len(files),
                success_count=success_count,
                documents=documents,
            )
        finally:
            self._cleanup_temp_paths(cleanup_paths)

    @staticmethod
    def _rejected_document(
        filename: str, media_type: str, warning: str
    ) -> OcrRecognizeDocumentRead:
        """Build the result entry for a file that was not OCR'd, with one warning."""
        return OcrRecognizeDocumentRead(
            filename=filename,
            media_type=media_type,
            warnings=[warning],
        )

    def _resolve_python_bin(self) -> str:
        """Return the interpreter used to run the OCR worker.

        Candidates in priority order: ``settings.ocr_python_bin`` (a path, or a
        command name looked up on ``PATH``), ``SERVER_DIR/.venv-ocr312/bin/python``,
        ``/usr/local/bin/python3.12``, then ``python3.12`` found on ``PATH``.
        The first candidate that exists on disk is returned.

        Raises:
            RuntimeError: if no candidate exists.
        """
        candidates: list[str] = []
        configured = str(self.settings.ocr_python_bin or "").strip()
        if configured:
            candidates.append(configured)
            # Also accept a bare command name such as "python3.12": a plain
            # Path(...).exists() check would only look in the working directory.
            configured_on_path = shutil.which(configured)
            if configured_on_path:
                candidates.append(configured_on_path)
        candidates.append(str(SERVER_DIR / ".venv-ocr312" / "bin" / "python"))
        candidates.append("/usr/local/bin/python3.12")
        resolved = shutil.which("python3.12")
        if resolved:
            candidates.append(resolved)
        for candidate in candidates:
            if candidate and Path(candidate).exists():
                return candidate
        raise RuntimeError(
            "未找到可用的 OCR Python 运行时。请先执行 scripts/bootstrap_paddleocr_mobile.sh "
            "或通过 OCR_PYTHON_BIN 指向已安装 PaddleOCR 的 Python 3.12。"
        )

    @staticmethod
    def _resolve_worker_path() -> str:
        """Return the path of ``scripts/paddle_ocr_worker.py`` under ``SERVER_DIR``.

        Raises:
            RuntimeError: if the worker script does not exist.
        """
        worker_path = SERVER_DIR / "scripts" / "paddle_ocr_worker.py"
        if not worker_path.exists():
            raise RuntimeError(f"OCR worker 不存在:{worker_path}")
        return str(worker_path)

    def _invoke_worker(
        self,
        *,
        python_bin: str,
        worker_path: str,
        input_paths: list[Path],
    ) -> dict:
        """Run the OCR worker once over all ``input_paths`` and return its JSON result.

        Language and detection/recognition model names come from settings;
        each input is passed as a separate ``--input`` argument. The command is
        an argument list (no shell), so paths are never shell-interpreted.

        Raises:
            RuntimeError: if the worker cannot be started, times out, exits
                non-zero, or prints no parsable JSON object.
        """
        command = [
            python_bin,
            worker_path,
            "--lang",
            self.settings.ocr_language,
            "--text-detection-model",
            self.settings.ocr_text_detection_model,
            "--text-recognition-model",
            self.settings.ocr_text_recognition_model,
        ]
        for path in input_paths:
            command.extend(["--input", str(path)])
        try:
            completed = subprocess.run(
                command,
                capture_output=True,
                text=True,
                timeout=self.settings.ocr_timeout_seconds,
                check=False,
            )
        except subprocess.TimeoutExpired as exc:
            # Surface as RuntimeError like every other worker failure so
            # callers handle a single exception type.
            raise RuntimeError(
                f"OCR 执行超时:worker 超过 {self.settings.ocr_timeout_seconds} 秒未完成。"
            ) from exc
        except OSError as exc:
            # e.g. the interpreter exists but is not executable.
            raise RuntimeError(f"OCR 执行失败:{exc}") from exc
        if completed.returncode != 0:
            detail = (completed.stderr or completed.stdout or "").strip()
            raise RuntimeError(f"OCR 执行失败:{detail or 'worker 返回非 0 状态码。'}")
        payload = self._parse_worker_stdout(completed.stdout)
        if payload is None:
            raise RuntimeError("OCR worker 未返回可解析的 JSON 结果。")
        return payload

    @staticmethod
    def _parse_worker_stdout(stdout: str) -> dict | None:
        """Extract the worker's JSON result from its stdout.

        The last line starting with ``WORKER_JSON_PREFIX`` is decoded; other
        lines are ignored. Returns ``None`` when there is no such line, when
        it is not valid JSON, or when it does not hold a JSON object. This lets
        the caller report one "unparsable result" error instead of leaking a
        ``JSONDecodeError`` or failing later on a non-dict payload.
        """
        for line in reversed(stdout.splitlines()):
            normalized = line.strip()
            if not normalized.startswith(WORKER_JSON_PREFIX):
                continue
            try:
                payload = json.loads(normalized[len(WORKER_JSON_PREFIX) :])
            except json.JSONDecodeError:
                return None
            return payload if isinstance(payload, dict) else None
        return None

    def _prepare_pdf_inputs(
        self,
        *,
        pdf_path: Path,
        filename: str,
        media_type: str,
        cleanup_paths: list[Path],
    ) -> list[PreparedOcrInput]:
        """Rasterise a PDF and describe each page as a separate OCR input.

        All pages share one fresh ``source_key`` so their results are merged
        into a single document; only page 0 carries the preview image. The
        page directory is registered in ``cleanup_paths`` before conversion
        so partial output is removed too.

        Raises:
            RuntimeError: if conversion fails or produces no pages.
        """
        output_dir = pdf_path.with_suffix("")
        output_dir.mkdir(parents=True, exist_ok=True)
        cleanup_paths.append(output_dir)
        image_paths = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=output_dir)
        if not image_paths:
            raise RuntimeError("PDF 转图片后未生成可识别页面。")
        preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png")
        source_key = uuid4().hex
        return [
            PreparedOcrInput(
                input_path=image_path,
                source_key=source_key,
                filename=filename,
                media_type=media_type,
                page_index=page_index,
                preview_kind="image" if page_index == 0 else "",
                preview_data_url=preview_data_url if page_index == 0 else "",
            )
            for page_index, image_path in enumerate(image_paths)
        ]

    def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]:
        """Render every page of ``pdf_path`` to PNG in ``output_dir`` via ``pdftoppm``.

        Pages are rendered at 160 DPI (``-r 160``) as ``page-<n>.png`` and
        returned sorted by numeric page number.

        Raises:
            RuntimeError: if ``pdftoppm`` cannot be started (e.g. not
                installed), times out, or exits non-zero. ``recognize_files``
                turns this into a per-file warning.
        """
        prefix = output_dir / "page"
        try:
            completed = subprocess.run(
                [
                    "pdftoppm",
                    "-png",
                    "-r",
                    "160",
                    str(pdf_path),
                    str(prefix),
                ],
                capture_output=True,
                text=True,
                timeout=self.settings.ocr_timeout_seconds,
                check=False,
            )
        except subprocess.TimeoutExpired as exc:
            raise RuntimeError(
                f"PDF 转图片超时:超过 {self.settings.ocr_timeout_seconds} 秒未完成。"
            ) from exc
        except OSError as exc:
            # Most commonly FileNotFoundError because pdftoppm is not on PATH.
            raise RuntimeError(f"PDF 转图片失败:{exc}") from exc
        if completed.returncode != 0:
            detail = (completed.stderr or completed.stdout or "").strip()
            raise RuntimeError(f"PDF 转图片失败:{detail or 'pdftoppm 返回非 0 状态码。'}")
        return sorted(output_dir.glob("page-*.png"), key=self._extract_pdf_page_sort_key)

    @staticmethod
    def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]:
        """Sort key ``(page_number, name)`` for ``page-<n>.png`` files.

        Sorting numerically keeps ``page-10`` after ``page-9`` regardless of
        zero padding; names without a numeric suffix sort as page 0.
        """
        suffix = path.stem.rsplit("-", 1)[-1]
        try:
            return int(suffix), path.name
        except ValueError:
            return 0, path.name

    @staticmethod
    def _build_preview_data_url(path: Path, *, media_type: str) -> str:
        """Return the file at ``path`` as a base64 ``data:`` URL of ``media_type``."""
        encoded = base64.b64encode(path.read_bytes()).decode("ascii")
        return f"data:{media_type};base64,{encoded}"

    def _build_documents(
        self,
        *,
        worker_documents: list[dict],
        prepared_inputs: list[PreparedOcrInput],
    ) -> list[OcrRecognizeDocumentRead]:
        """Merge per-image worker results back into one document per upload.

        Worker entries are matched to their ``PreparedOcrInput`` by
        ``input_path``; non-dict entries and unknown paths are ignored. Output
        follows the order in which sources first appear in ``prepared_inputs``.
        A source with no worker result at all gets a placeholder document
        carrying a warning.
        """
        descriptor_by_path = {str(item.input_path): item for item in prepared_inputs}
        # Group descriptors by source once; dicts keep insertion order, which
        # fixes the output order of the documents.
        descriptors_by_source: dict[str, list[PreparedOcrInput]] = {}
        for item in prepared_inputs:
            descriptors_by_source.setdefault(item.source_key, []).append(item)
        aggregated_by_source: dict[str, AggregatedOcrDocument] = {}
        for payload in worker_documents:
            if not isinstance(payload, dict):
                continue
            descriptor = descriptor_by_path.get(str(payload.get("input_path") or ""))
            if descriptor is None:
                continue
            aggregated = aggregated_by_source.get(descriptor.source_key)
            if aggregated is None:
                aggregated = AggregatedOcrDocument(
                    filename=descriptor.filename,
                    media_type=descriptor.media_type,
                    source_key=descriptor.source_key,
                    engine=str(payload.get("engine", "paddleocr_mobile")),
                    model=str(payload.get("model", "PP-OCRv5_mobile")),
                )
                aggregated_by_source[descriptor.source_key] = aggregated
            # PDF pages know their own 0-based index; a standalone image
            # relies on the page count reported by the worker.
            aggregated.page_count = max(
                aggregated.page_count,
                (descriptor.page_index + 1)
                if descriptor.page_index is not None
                else int(payload.get("page_count", 1) or 1),
            )
            if descriptor.preview_kind and not aggregated.preview_kind:
                aggregated.preview_kind = descriptor.preview_kind
            if descriptor.preview_data_url and not aggregated.preview_data_url:
                aggregated.preview_data_url = descriptor.preview_data_url
            page_summary = str(payload.get("summary", "") or "").strip()
            if page_summary:
                aggregated.summary_fragments.append(page_summary)
            page_text = str(payload.get("text", "") or "").strip()
            if page_text:
                aggregated.text_fragments.append(page_text)
            # `or []` guards against an explicit JSON null from the worker.
            lines = self._build_lines(
                payload.get("lines") or [],
                page_index_override=descriptor.page_index,
            )
            aggregated.lines.extend(lines)
            aggregated.score_values.extend(line.score for line in lines if line.score > 0)
            if not lines:
                # No line-level scores for this page: fall back to its average.
                avg_score = float(payload.get("avg_score", 0.0) or 0.0)
                if avg_score > 0:
                    aggregated.score_values.append(avg_score)
            for warning in payload.get("warnings") or []:
                normalized_warning = str(warning or "").strip()
                if normalized_warning and normalized_warning not in aggregated.warnings:
                    aggregated.warnings.append(normalized_warning)
        documents: list[OcrRecognizeDocumentRead] = []
        for source_key, descriptors in descriptors_by_source.items():
            aggregated = aggregated_by_source.get(source_key)
            if aggregated is not None:
                documents.append(self._finalize_document(aggregated))
                continue
            first_descriptor = descriptors[0]
            documents.append(
                OcrRecognizeDocumentRead(
                    filename=first_descriptor.filename,
                    media_type=first_descriptor.media_type,
                    page_count=max(1, len(descriptors)),
                    preview_kind=first_descriptor.preview_kind,
                    preview_data_url=first_descriptor.preview_data_url,
                    warnings=["OCR worker 未返回该文件的识别结果。"],
                )
            )
        return documents

    @staticmethod
    def _build_lines(
        items: list[dict],
        *,
        page_index_override: int | None = None,
    ) -> list[OcrRecognizeLineRead]:
        """Convert the worker's raw line dicts into ``OcrRecognizeLineRead`` models.

        Non-dict items are skipped. ``page_index_override`` (set for PDF
        pages) takes precedence over a ``page_index`` in the item. Box points
        are kept only when they are 2-element lists and are truncated to int.
        """
        lines: list[OcrRecognizeLineRead] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            page_index = page_index_override
            if page_index is None and item.get("page_index") is not None:
                page_index = int(item["page_index"])
            lines.append(
                OcrRecognizeLineRead(
                    text=str(item.get("text", "")),
                    score=float(item.get("score", 0.0) or 0.0),
                    box=[
                        [int(point[0]), int(point[1])]
                        for point in item.get("box") or []
                        if isinstance(point, list) and len(point) == 2
                    ],
                    page_index=page_index,
                )
            )
        return lines

    @staticmethod
    def _truncate_summary(parts: list[str]) -> str:
        """Concatenate (without separator) the first three non-empty parts.

        Results longer than 180 characters are cut to 177 characters plus "...".
        """
        summary = "".join([part for part in parts if part][:3])
        if len(summary) > 180:
            return f"{summary[:177]}..."
        return summary

    def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead:
        """Turn an aggregated upload into the API response model.

        Page texts are joined with newlines. The summary is built from page
        summaries, falling back to page texts. ``avg_score`` is the mean of
        the collected scores (0.0 if none). Document-intelligence
        classification runs on the result and its warnings are merged
        (de-duplicated). Lines are ordered by page with a stable sort, so
        within-page order is kept and lines without a page index come first.
        """
        full_text = "\n".join(fragment for fragment in aggregated.text_fragments if fragment).strip()
        summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments)
        insight = self.document_intelligence_service.build_document_insight(
            filename=aggregated.filename,
            summary=summary,
            text=full_text,
            preview_data_url=aggregated.preview_data_url,
        )
        warnings = list(aggregated.warnings)
        for warning in insight.warnings:
            normalized_warning = str(warning or "").strip()
            if normalized_warning and normalized_warning not in warnings:
                warnings.append(normalized_warning)
        return OcrRecognizeDocumentRead(
            filename=aggregated.filename,
            media_type=aggregated.media_type,
            engine=aggregated.engine,
            model=aggregated.model,
            text=full_text,
            summary=summary,
            avg_score=(
                sum(aggregated.score_values) / len(aggregated.score_values)
                if aggregated.score_values
                else 0.0
            ),
            line_count=len(aggregated.lines),
            page_count=max(1, aggregated.page_count),
            document_type=insight.document_type,
            document_type_label=insight.document_type_label,
            scene_code=insight.scene_code,
            scene_label=insight.scene_label,
            classification_source=insight.classification_source,
            classification_confidence=insight.classification_confidence,
            classification_evidence=list(insight.evidence),
            document_fields=[
                OcrRecognizeFieldRead(
                    key=insight_field.key,
                    label=insight_field.label,
                    value=insight_field.value,
                )
                for insight_field in insight.fields
            ],
            preview_kind=aggregated.preview_kind,
            preview_data_url=aggregated.preview_data_url,
            warnings=warnings,
            lines=sorted(
                aggregated.lines,
                key=lambda item: item.page_index if item.page_index is not None else -1,
            ),
        )

    @staticmethod
    def _cleanup_temp_paths(paths: list[Path]) -> None:
        """Best-effort removal of temporary files and directories.

        Called from ``recognize_files``'s ``finally`` block, so it must never
        raise: an exception here would replace the batch result (or the real
        error) with an unrelated cleanup failure. Paths are processed in
        reverse order of registration; missing paths are ignored.
        """
        for path in reversed(paths):
            if path.is_dir():
                shutil.rmtree(path, ignore_errors=True)
                continue
            try:
                path.unlink(missing_ok=True)
            except OSError:
                # A leftover temp file is preferable to masking the outcome.
                pass