# Source file: X-Financial/server/src/app/services/ocr.py (746 lines, 29 KiB, Python)

from __future__ import annotations
import base64
import hashlib
import json
import re
import shutil
import subprocess
import threading
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from uuid import uuid4
from sqlalchemy.orm import Session
from app.core.config import SERVER_DIR, get_settings
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
from app.services.document_preview import DocumentPreviewAssets
from app.services.document_intelligence import DocumentIntelligenceService
# Marker that precedes the JSON result on the worker's stdout; the last stdout
# line starting with it is the one that gets parsed.
WORKER_JSON_PREFIX = "__OCR_JSON__="
# Lower-cased file suffixes accepted for OCR; other uploads are answered with a warning.
SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"}
# Maximum number of documents kept in the in-process result cache.
OCR_RESULT_CACHE_LIMIT = 32
# First component of every cache key. It embeds the PDF renderer id, so a key
# built under a different renderer id never matches.
OCR_RESULT_CACHE_PIPELINE_VERSION = f"pdf-image-ocr:{DocumentPreviewAssets.PDF_RENDERER_ID}:no-pdf-direct-v2"
@dataclass(slots=True)
class PreparedOcrInput:
input_path: Path
source_key: str
filename: str
media_type: str
page_index: int | None = None
preview_kind: str = ""
preview_data_url: str = ""
text_layer: str = ""
@dataclass(slots=True)
class AggregatedOcrDocument:
filename: str
media_type: str
source_key: str
engine: str = "paddleocr_mobile"
model: str = "PP-OCRv5_mobile"
summary_fragments: list[str] = field(default_factory=list)
text_fragments: list[str] = field(default_factory=list)
text_layer_fragments: list[str] = field(default_factory=list)
score_values: list[float] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
lines: list[OcrRecognizeLineRead] = field(default_factory=list)
page_count: int = 0
preview_kind: str = ""
preview_data_url: str = ""
class OcrService:
    """OCR front end: runs an external worker process and caches its results."""

    # --- result cache, shared by all instances ---
    # Guards every access to ``_result_cache``.
    _cache_lock = threading.Lock()
    # Cache key -> recognized document, ordered from least to most recently used.
    _result_cache: OrderedDict[str, OcrRecognizeDocumentRead] = OrderedDict()

    # --- worker concurrency limit, shared by all instances ---
    # Guards creation and replacement of ``_worker_semaphore``.
    _worker_semaphore_lock = threading.Lock()
    # Created lazily on first use; ``None`` until then.
    _worker_semaphore: threading.Semaphore | None = None
    # Limit the current semaphore was created with (0 while none exists).
    _worker_semaphore_limit = 0
def __init__(self, db: Session | None = None) -> None:
    """Load the application settings and create the document-intelligence helper.

    Args:
        db: Optional SQLAlchemy session; it is passed unchanged to
            ``DocumentIntelligenceService``.
    """
    self.settings = get_settings()
    intelligence_service = DocumentIntelligenceService(db)
    self.document_intelligence_service = intelligence_service
def recognize_files(
self,
files: list[tuple[str, bytes, str | None]],
) -> OcrRecognizeBatchRead:
if not files:
raise ValueError("至少需要上传一个文件。")
temp_root = self.settings.resolved_ocr_temp_dir
temp_root.mkdir(parents=True, exist_ok=True)
documents: list[OcrRecognizeDocumentRead] = []
prepared_inputs: list[PreparedOcrInput] = []
cleanup_paths: list[Path] = []
worker_payload: dict = {}
cache_keys_by_source: dict[str, str] = {}
try:
for filename, content, media_type in files:
normalized_name = Path(str(filename or "").strip()).name or "upload.bin"
suffix = Path(normalized_name).suffix.lower()
resolved_media_type = str(media_type or "application/octet-stream")
if not content:
documents.append(
OcrRecognizeDocumentRead(
filename=normalized_name,
media_type=resolved_media_type,
warnings=["文件内容为空,未执行 OCR。"],
)
)
continue
if suffix not in SUPPORTED_SUFFIXES:
documents.append(
OcrRecognizeDocumentRead(
filename=normalized_name,
media_type=resolved_media_type,
warnings=["当前仅支持图片和 PDF 文件进行 OCR。"],
)
)
continue
if len(content) > self.settings.ocr_max_file_size_mb * 1024 * 1024:
documents.append(
OcrRecognizeDocumentRead(
filename=normalized_name,
media_type=resolved_media_type,
warnings=[
f"文件超过 {self.settings.ocr_max_file_size_mb} MB未执行 OCR。"
],
)
)
continue
cache_key = self._build_cache_key(content)
cached_document = self._read_cached_document(
cache_key,
filename=normalized_name,
media_type=resolved_media_type,
)
if cached_document is not None:
documents.append(cached_document)
continue
temp_path = temp_root / f"{uuid4().hex}{suffix}"
temp_path.write_bytes(content)
cleanup_paths.append(temp_path)
if suffix == ".pdf":
try:
text_layer = self._extract_pdf_text_layer(temp_path)
pdf_inputs = self._prepare_pdf_inputs(
pdf_path=temp_path,
filename=normalized_name,
media_type=resolved_media_type,
cleanup_paths=cleanup_paths,
text_layer=text_layer,
)
prepared_inputs.extend(pdf_inputs)
for item in pdf_inputs:
cache_keys_by_source.setdefault(item.source_key, cache_key)
except RuntimeError as exc:
documents.append(
OcrRecognizeDocumentRead(
filename=normalized_name,
media_type=resolved_media_type,
warnings=[str(exc)],
)
)
continue
source_key = uuid4().hex
prepared_inputs.append(
PreparedOcrInput(
input_path=temp_path,
source_key=source_key,
filename=normalized_name,
media_type=resolved_media_type,
preview_kind="image" if resolved_media_type.startswith("image/") else "",
preview_data_url=(
self._build_preview_data_url(temp_path, media_type=resolved_media_type)
if resolved_media_type.startswith("image/")
else ""
),
)
)
cache_keys_by_source[source_key] = cache_key
if prepared_inputs:
python_bin = self._resolve_python_bin()
worker_path = self._resolve_worker_path()
worker_payload = self._invoke_worker(
python_bin=python_bin,
worker_path=worker_path,
input_paths=[item.input_path for item in prepared_inputs],
)
recognized_documents = self._build_documents(
worker_documents=worker_payload.get("documents", []),
prepared_inputs=prepared_inputs,
)
documents.extend(recognized_documents)
self._write_cached_documents(
recognized_documents,
prepared_inputs=prepared_inputs,
cache_keys_by_source=cache_keys_by_source,
)
success_count = sum(
1
for item in documents
if item.line_count > 0 or not item.warnings
)
engine = (
str(worker_payload.get("engine", "paddleocr_mobile"))
if prepared_inputs
else "paddleocr_mobile"
)
model = (
str(worker_payload.get("model", "PP-OCRv5_mobile"))
if prepared_inputs
else "PP-OCRv5_mobile"
)
return OcrRecognizeBatchRead(
engine=engine,
model=model,
total_file_count=len(files),
success_count=success_count,
documents=documents,
)
finally:
self._cleanup_temp_paths(cleanup_paths)
def _resolve_python_bin(self) -> str:
candidates = []
configured = str(self.settings.ocr_python_bin or "").strip()
if configured:
candidates.append(configured)
candidates.append(str(SERVER_DIR / ".venv-ocr312" / "bin" / "python"))
candidates.append("/usr/local/bin/python3.12")
resolved = shutil.which("python3.12")
if resolved:
candidates.append(resolved)
for candidate in candidates:
if candidate and Path(candidate).exists():
return candidate
raise RuntimeError(
"未找到可用的 OCR Python 运行时。请先执行 scripts/bootstrap_paddleocr_mobile.sh "
"或通过 OCR_PYTHON_BIN 指向已安装 PaddleOCR 的 Python 3.12。"
)
@staticmethod
def _resolve_worker_path() -> str:
    """Return the path of the PaddleOCR worker script as a string.

    Raises:
        RuntimeError: If ``scripts/paddle_ocr_worker.py`` does not exist
            below ``SERVER_DIR``.
    """
    script = SERVER_DIR / "scripts" / "paddle_ocr_worker.py"
    if script.exists():
        return str(script)
    raise RuntimeError(f"OCR worker 不存在:{script}")
def _build_cache_key(self, content: bytes) -> str:
    """Build the result-cache key for one uploaded file.

    The key joins the pipeline version, the OCR settings that are passed to
    the worker (language, device, detection and recognition model) and the
    SHA-256 of the content with ``|``.

    Unset settings are rendered as an empty component. The worker invocation
    already treats ``ocr_device`` as optional, and joining a ``None`` value
    would otherwise raise ``TypeError``.

    Args:
        content: Raw bytes of the uploaded file.

    Returns:
        The cache key string.
    """
    settings = self.settings
    components = (
        OCR_RESULT_CACHE_PIPELINE_VERSION,
        settings.ocr_language,
        settings.ocr_device,
        settings.ocr_text_detection_model,
        settings.ocr_text_recognition_model,
        hashlib.sha256(content).hexdigest(),
    )
    return "|".join(str(component or "") for component in components)
@classmethod
def _read_cached_document(
cls,
cache_key: str,
*,
filename: str,
media_type: str,
) -> OcrRecognizeDocumentRead | None:
if not cache_key:
return None
with cls._cache_lock:
cached = cls._result_cache.get(cache_key)
if cached is None:
return None
cls._result_cache.move_to_end(cache_key)
return cached.model_copy(update={"filename": filename, "media_type": media_type})
@classmethod
def _write_cached_documents(
cls,
documents: list[OcrRecognizeDocumentRead],
*,
prepared_inputs: list[PreparedOcrInput],
cache_keys_by_source: dict[str, str],
) -> None:
if not documents or not cache_keys_by_source:
return
source_order: list[str] = []
seen_sources: set[str] = set()
for item in prepared_inputs:
if item.source_key in seen_sources:
continue
seen_sources.add(item.source_key)
source_order.append(item.source_key)
with cls._cache_lock:
for source_key, document in zip(source_order, documents, strict=False):
cache_key = cache_keys_by_source.get(source_key, "")
if not cache_key:
continue
cls._result_cache[cache_key] = document.model_copy(
update={
"receipt_id": "",
"receipt_status": "",
"receipt_preview_url": "",
"receipt_source_url": "",
}
)
cls._result_cache.move_to_end(cache_key)
while len(cls._result_cache) > OCR_RESULT_CACHE_LIMIT:
cls._result_cache.popitem(last=False)
@classmethod
def _write_cached_document(cls, cache_key: str, document: OcrRecognizeDocumentRead) -> None:
if not cache_key:
return
with cls._cache_lock:
cls._result_cache[cache_key] = document.model_copy(
update={
"receipt_id": "",
"receipt_status": "",
"receipt_preview_url": "",
"receipt_source_url": "",
}
)
cls._result_cache.move_to_end(cache_key)
while len(cls._result_cache) > OCR_RESULT_CACHE_LIMIT:
cls._result_cache.popitem(last=False)
@classmethod
def _resolve_worker_semaphore(cls, limit: int) -> threading.Semaphore:
normalized_limit = max(1, int(limit or 1))
with cls._worker_semaphore_lock:
if cls._worker_semaphore is None or cls._worker_semaphore_limit != normalized_limit:
cls._worker_semaphore = threading.Semaphore(normalized_limit)
cls._worker_semaphore_limit = normalized_limit
return cls._worker_semaphore
def _invoke_worker(
self,
*,
python_bin: str,
worker_path: str,
input_paths: list[Path],
) -> dict:
command = [
python_bin,
worker_path,
"--lang",
self.settings.ocr_language,
"--text-detection-model",
self.settings.ocr_text_detection_model,
"--text-recognition-model",
self.settings.ocr_text_recognition_model,
]
configured_device = str(self.settings.ocr_device or "").strip()
if configured_device:
command.extend(["--device", configured_device])
for path in input_paths:
command.extend(["--input", str(path)])
semaphore = self._resolve_worker_semaphore(self.settings.ocr_max_concurrent_workers)
with semaphore:
completed = subprocess.run(
command,
capture_output=True,
text=True,
timeout=self.settings.ocr_timeout_seconds,
check=False,
)
if completed.returncode != 0:
detail = (completed.stderr or completed.stdout or "").strip()
raise RuntimeError(f"OCR 执行失败:{detail or 'worker 返回非 0 状态码。'}")
payload = self._parse_worker_stdout(completed.stdout)
if payload is None:
raise RuntimeError("OCR worker 未返回可解析的 JSON 结果。")
return payload
@staticmethod
def _parse_worker_stdout(stdout: str) -> dict | None:
for line in reversed(stdout.splitlines()):
normalized = line.strip()
if normalized.startswith(WORKER_JSON_PREFIX):
return json.loads(normalized[len(WORKER_JSON_PREFIX) :])
return None
def _prepare_pdf_inputs(
    self,
    *,
    pdf_path: Path,
    filename: str,
    media_type: str,
    cleanup_paths: list[Path],
    text_layer: str = "",
) -> list[PreparedOcrInput]:
    """Render a PDF to page images and describe each page as a worker input.

    The images are written to a directory named after the PDF without its
    suffix; that directory is appended to ``cleanup_paths``. All pages share
    one new ``source_key``. Only page 0 carries the preview and the text layer.

    Args:
        pdf_path: Temporary copy of the uploaded PDF.
        filename: File name to report for every page.
        media_type: Media type to report for every page.
        cleanup_paths: List that receives the page-image directory.
        text_layer: Text already extracted from the PDF text layer, if any.

    Returns:
        One descriptor per rendered page, in page order.

    Raises:
        RuntimeError: If rendering fails or yields no page.
    """
    page_dir = pdf_path.with_suffix("")
    page_dir.mkdir(parents=True, exist_ok=True)
    cleanup_paths.append(page_dir)

    image_paths, preview_usable = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=page_dir)
    if not image_paths:
        raise RuntimeError("PDF 转图片后未生成可识别页面。")

    if preview_usable:
        preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png")
    else:
        preview_data_url = ""

    source_key = uuid4().hex
    return [
        PreparedOcrInput(
            input_path=image_path,
            source_key=source_key,
            filename=filename,
            media_type=media_type,
            page_index=page_index,
            preview_kind="image" if (page_index == 0 and preview_data_url) else "",
            preview_data_url=preview_data_url if page_index == 0 else "",
            text_layer=text_layer if page_index == 0 else "",
        )
        for page_index, image_path in enumerate(image_paths)
    ]
def _extract_pdf_text_layer(self, pdf_path: Path) -> str:
try:
completed = subprocess.run(
[
"pdftotext",
"-layout",
str(pdf_path),
"-",
],
capture_output=True,
text=True,
timeout=self.settings.ocr_timeout_seconds,
check=False,
)
except (OSError, subprocess.SubprocessError, UnicodeError):
return ""
if completed.returncode != 0:
return ""
return self._normalize_extracted_text(completed.stdout)
def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> tuple[list[Path], bool]:
    """Render every page of a PDF into ``output_dir``.

    Args:
        pdf_path: PDF file to render.
        output_dir: Existing directory that receives the page images.

    Returns:
        The rendered page paths and a flag saying whether the first page may
        be used as a preview (always ``True`` here).

    Raises:
        RuntimeError: If the renderer reports a ``RuntimeError``; the
            original error is chained as the cause.
    """
    timeout_seconds = self.settings.ocr_timeout_seconds
    try:
        rendered_pages = DocumentPreviewAssets.render_pdf_pages(
            pdf_path=pdf_path,
            output_dir=output_dir,
            timeout_seconds=timeout_seconds,
        )
    except RuntimeError as exc:
        raise RuntimeError(f"PDF 转图片失败:{exc}") from exc
    else:
        return rendered_pages, True
@staticmethod
def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]:
suffix = path.stem.rsplit("-", 1)[-1]
try:
return int(suffix), path.name
except ValueError:
return 0, path.name
@staticmethod
def _build_preview_data_url(path: Path, *, media_type: str) -> str:
encoded = base64.b64encode(path.read_bytes()).decode("ascii")
return f"data:{media_type};base64,{encoded}"
def _build_documents(
self,
*,
worker_documents: list[dict],
prepared_inputs: list[PreparedOcrInput],
) -> list[OcrRecognizeDocumentRead]:
descriptor_by_path = {str(item.input_path): item for item in prepared_inputs}
source_order: list[str] = []
seen_sources: set[str] = set()
for item in prepared_inputs:
if item.source_key in seen_sources:
continue
seen_sources.add(item.source_key)
source_order.append(item.source_key)
aggregated_by_source: dict[str, AggregatedOcrDocument] = {}
for payload in worker_documents:
if not isinstance(payload, dict):
continue
input_path = str(payload.get("input_path") or "")
descriptor = descriptor_by_path.get(input_path)
if descriptor is None:
continue
aggregated = aggregated_by_source.get(descriptor.source_key)
if aggregated is None:
aggregated = AggregatedOcrDocument(
filename=descriptor.filename,
media_type=descriptor.media_type,
source_key=descriptor.source_key,
engine=str(payload.get("engine", "paddleocr_mobile")),
model=str(payload.get("model", "PP-OCRv5_mobile")),
)
aggregated_by_source[descriptor.source_key] = aggregated
aggregated.page_count = max(
aggregated.page_count,
(descriptor.page_index + 1)
if descriptor.page_index is not None
else int(payload.get("page_count", 1) or 1),
)
if descriptor.preview_kind and not aggregated.preview_kind:
aggregated.preview_kind = descriptor.preview_kind
if descriptor.preview_data_url and not aggregated.preview_data_url:
aggregated.preview_data_url = descriptor.preview_data_url
if descriptor.text_layer and descriptor.text_layer not in aggregated.text_layer_fragments:
aggregated.text_layer_fragments.append(descriptor.text_layer)
page_summary = str(payload.get("summary", "") or "").strip()
if page_summary:
aggregated.summary_fragments.append(page_summary)
page_text = self._resolve_worker_document_text(payload)
if page_text:
aggregated.text_fragments.append(page_text)
lines = self._build_lines(
payload.get("lines", []),
page_index_override=descriptor.page_index,
)
aggregated.lines.extend(lines)
aggregated.score_values.extend(line.score for line in lines if line.score > 0)
if not lines:
avg_score = float(payload.get("avg_score", 0.0) or 0.0)
if avg_score > 0:
aggregated.score_values.append(avg_score)
for warning in payload.get("warnings", []):
normalized_warning = str(warning or "").strip()
if normalized_warning and normalized_warning not in aggregated.warnings:
aggregated.warnings.append(normalized_warning)
documents: list[OcrRecognizeDocumentRead] = []
for source_key in source_order:
descriptors = [item for item in prepared_inputs if item.source_key == source_key]
if not descriptors:
continue
aggregated = aggregated_by_source.get(source_key)
if aggregated is None:
first_descriptor = descriptors[0]
text_layer = self._collect_descriptor_text_layer(descriptors)
if text_layer:
fallback = AggregatedOcrDocument(
filename=first_descriptor.filename,
media_type=first_descriptor.media_type,
source_key=first_descriptor.source_key,
page_count=max(1, len(descriptors)),
preview_kind=first_descriptor.preview_kind,
preview_data_url=first_descriptor.preview_data_url,
warnings=["OCR worker 未返回该文件的识别结果,已使用 PDF 文本层。"],
)
fallback.text_layer_fragments.append(text_layer)
documents.append(self._finalize_document(fallback))
continue
documents.append(
OcrRecognizeDocumentRead(
filename=first_descriptor.filename,
media_type=first_descriptor.media_type,
page_count=max(1, len(descriptors)),
preview_kind=first_descriptor.preview_kind,
preview_data_url=first_descriptor.preview_data_url,
warnings=["OCR worker 未返回该文件的识别结果。"],
)
)
continue
documents.append(self._finalize_document(aggregated))
return documents
@staticmethod
def _collect_descriptor_text_layer(descriptors: list[PreparedOcrInput]) -> str:
for descriptor in descriptors:
if descriptor.text_layer:
return descriptor.text_layer
return ""
@staticmethod
def _resolve_worker_document_text(payload: dict) -> str:
for key in ("text", "ocr_text", "raw_text", "full_text"):
value = str(payload.get(key, "") or "").strip()
if value:
return value
lines = payload.get("lines", [])
if not isinstance(lines, list):
return ""
return "\n".join(
str(item.get("text", "") or "").strip()
for item in lines
if isinstance(item, dict) and str(item.get("text", "") or "").strip()
).strip()
@staticmethod
def _build_lines(
items: list[dict],
*,
page_index_override: int | None = None,
) -> list[OcrRecognizeLineRead]:
lines: list[OcrRecognizeLineRead] = []
for item in items:
if not isinstance(item, dict):
continue
page_index = page_index_override
if page_index is None and item.get("page_index") is not None:
page_index = int(item["page_index"])
lines.append(
OcrRecognizeLineRead(
text=str(item.get("text", "")),
score=float(item.get("score", 0.0) or 0.0),
box=[
[int(point[0]), int(point[1])]
for point in item.get("box", [])
if isinstance(point, list) and len(point) == 2
],
page_index=page_index,
)
)
return lines
@staticmethod
def _truncate_summary(parts: list[str]) -> str:
summary = "".join([part for part in parts if part][:3])
if len(summary) > 180:
return f"{summary[:177]}..."
return summary
def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead:
    """Turn a per-upload aggregate into the response document.

    The final text is chosen between the joined OCR text and the joined PDF
    text layer. The summary comes from the worker summaries (or the OCR text
    when there are none) and is rebuilt from the final text when the text
    layer was chosen or the summary contains too many placeholder characters.
    Classification data and extra warnings come from the
    document-intelligence service.

    Args:
        aggregated: Merged worker results of one upload.

    Returns:
        The finished document, with lines ordered by page index (lines
        without a page index first).
    """
    joined_ocr_text = "\n".join(filter(None, aggregated.text_fragments)).strip()
    joined_text_layer = "\n".join(filter(None, aggregated.text_layer_fragments)).strip()
    full_text, used_text_layer = self._choose_document_text(
        ocr_text=joined_ocr_text,
        text_layer=joined_text_layer,
    )

    summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments)
    if used_text_layer or self._placeholder_ratio(summary) >= 0.12:
        summary = self._summarize_text(full_text)

    insight = self.document_intelligence_service.build_document_insight(
        filename=aggregated.filename,
        summary=summary,
        text=full_text,
        preview_data_url=aggregated.preview_data_url,
    )

    # Worker warnings first, then any new ones from the insight step.
    merged_warnings = list(aggregated.warnings)
    for raw_warning in insight.warnings:
        cleaned = str(raw_warning or "").strip()
        if cleaned and cleaned not in merged_warnings:
            merged_warnings.append(cleaned)

    scores = aggregated.score_values
    average_score = sum(scores) / len(scores) if scores else 0.0

    document_fields = [
        OcrRecognizeFieldRead(
            key=entry.key,
            label=entry.label,
            value=entry.value,
        )
        for entry in insight.fields
    ]
    ordered_lines = sorted(
        aggregated.lines,
        key=lambda line: -1 if line.page_index is None else line.page_index,
    )

    return OcrRecognizeDocumentRead(
        filename=aggregated.filename,
        media_type=aggregated.media_type,
        engine=aggregated.engine,
        model=aggregated.model,
        text=full_text,
        summary=summary,
        avg_score=average_score,
        line_count=len(aggregated.lines),
        page_count=max(1, aggregated.page_count),
        document_type=insight.document_type,
        document_type_label=insight.document_type_label,
        scene_code=insight.scene_code,
        scene_label=insight.scene_label,
        classification_source=insight.classification_source,
        classification_confidence=insight.classification_confidence,
        classification_evidence=list(insight.evidence),
        document_fields=document_fields,
        preview_kind=aggregated.preview_kind,
        preview_data_url=aggregated.preview_data_url,
        warnings=merged_warnings,
        lines=ordered_lines,
    )
@classmethod
def _choose_document_text(cls, *, ocr_text: str, text_layer: str) -> tuple[str, bool]:
normalized_ocr_text = cls._normalize_extracted_text(ocr_text)
normalized_text_layer = cls._normalize_extracted_text(text_layer)
if not normalized_text_layer:
return normalized_ocr_text, False
if not normalized_ocr_text:
return normalized_text_layer, True
if cls._placeholder_ratio(normalized_ocr_text) >= 0.12 and cls._meaningful_char_count(normalized_text_layer) >= 8:
return normalized_text_layer, True
if cls._meaningful_char_count(normalized_text_layer) > cls._meaningful_char_count(normalized_ocr_text) * 1.3:
return normalized_text_layer, True
return normalized_ocr_text, False
@staticmethod
def _normalize_extracted_text(value: str) -> str:
lines = [re.sub(r"[ \t]+", " ", line).strip() for line in str(value or "").replace("\r", "\n").split("\n")]
return "\n".join(line for line in lines if line).strip()
@staticmethod
def _summarize_text(value: str) -> str:
lines = [line.strip() for line in str(value or "").splitlines() if line.strip()]
summary = "".join(lines[:3])
if len(summary) > 180:
return f"{summary[:177]}..."
return summary
@staticmethod
def _meaningful_char_count(value: str) -> int:
return len(re.findall(r"[0-9A-Za-z\u4e00-\u9fff]", str(value or "")))
@staticmethod
def _placeholder_ratio(value: str) -> float:
chars = [char for char in str(value or "") if not char.isspace()]
if not chars:
return 0.0
placeholder_count = sum(1 for char in chars if char in {"", "<EFBFBD>"})
return placeholder_count / len(chars)
@staticmethod
def _cleanup_temp_paths(paths: list[Path]) -> None:
for path in reversed(paths):
if path.is_dir():
shutil.rmtree(path, ignore_errors=True)
continue
path.unlink(missing_ok=True)