feat(server): 新增文档智能识别服务,扩展OCR接口支持 Azure Document Intelligence

This commit is contained in:
caoxiaozhu
2026-05-14 09:32:15 +00:00
parent 8adeefe4a9
commit 8b39f48dec
7 changed files with 1128 additions and 61 deletions

View File

@@ -1,21 +1,55 @@
from __future__ import annotations
import base64
import json
import shutil
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from uuid import uuid4
from sqlalchemy.orm import Session
from app.core.config import SERVER_DIR, get_settings
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
from app.services.document_intelligence import DocumentIntelligenceService
WORKER_JSON_PREFIX = "__OCR_JSON__="
SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"}
@dataclass(slots=True)
class PreparedOcrInput:
input_path: Path
source_key: str
filename: str
media_type: str
page_index: int | None = None
preview_kind: str = ""
preview_data_url: str = ""
@dataclass(slots=True)
class AggregatedOcrDocument:
filename: str
media_type: str
source_key: str
engine: str = "paddleocr_mobile"
model: str = "PP-OCRv5_mobile"
summary_fragments: list[str] = field(default_factory=list)
text_fragments: list[str] = field(default_factory=list)
score_values: list[float] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
lines: list[OcrRecognizeLineRead] = field(default_factory=list)
page_count: int = 0
preview_kind: str = ""
preview_data_url: str = ""
class OcrService:
def __init__(self) -> None:
def __init__(self, db: Session | None = None) -> None:
self.settings = get_settings()
self.document_intelligence_service = DocumentIntelligenceService(db)
def recognize_files(
self,
@@ -28,10 +62,11 @@ class OcrService:
temp_root.mkdir(parents=True, exist_ok=True)
documents: list[OcrRecognizeDocumentRead] = []
input_paths: list[Path] = []
meta_by_path: dict[str, tuple[str, str]] = {}
prepared_inputs: list[PreparedOcrInput] = []
cleanup_paths: list[Path] = []
python_bin = self._resolve_python_bin()
worker_path = self._resolve_worker_path()
worker_payload: dict = {}
try:
for filename, content, media_type in files:
@@ -73,17 +108,55 @@ class OcrService:
temp_path = temp_root / f"{uuid4().hex}{suffix}"
temp_path.write_bytes(content)
input_paths.append(temp_path)
meta_by_path[str(temp_path)] = (normalized_name, resolved_media_type)
cleanup_paths.append(temp_path)
if input_paths:
if suffix == ".pdf":
try:
prepared_inputs.extend(
self._prepare_pdf_inputs(
pdf_path=temp_path,
filename=normalized_name,
media_type=resolved_media_type,
cleanup_paths=cleanup_paths,
)
)
except RuntimeError as exc:
documents.append(
OcrRecognizeDocumentRead(
filename=normalized_name,
media_type=resolved_media_type,
warnings=[str(exc)],
)
)
continue
prepared_inputs.append(
PreparedOcrInput(
input_path=temp_path,
source_key=uuid4().hex,
filename=normalized_name,
media_type=resolved_media_type,
preview_kind="image" if resolved_media_type.startswith("image/") else "",
preview_data_url=(
self._build_preview_data_url(temp_path, media_type=resolved_media_type)
if resolved_media_type.startswith("image/")
else ""
),
)
)
if prepared_inputs:
worker_payload = self._invoke_worker(
python_bin=python_bin,
worker_path=worker_path,
input_paths=input_paths,
input_paths=[item.input_path for item in prepared_inputs],
)
documents.extend(
self._build_documents(
worker_documents=worker_payload.get("documents", []),
prepared_inputs=prepared_inputs,
)
)
for item in worker_payload.get("documents", []):
documents.append(self._build_document(item, meta_by_path))
success_count = sum(
1
@@ -92,12 +165,12 @@ class OcrService:
)
engine = (
str(worker_payload.get("engine", "paddleocr_mobile"))
if input_paths
if prepared_inputs
else "paddleocr_mobile"
)
model = (
str(worker_payload.get("model", "PP-OCRv5_mobile"))
if input_paths
if prepared_inputs
else "PP-OCRv5_mobile"
)
return OcrRecognizeBatchRead(
@@ -108,8 +181,7 @@ class OcrService:
documents=documents,
)
finally:
for path in input_paths:
path.unlink(missing_ok=True)
self._cleanup_temp_paths(cleanup_paths)
def _resolve_python_bin(self) -> str:
candidates = []
@@ -182,40 +254,258 @@ class OcrService:
return json.loads(normalized[len(WORKER_JSON_PREFIX) :])
return None
@staticmethod
def _build_document(
payload: dict,
meta_by_path: dict[str, tuple[str, str]],
) -> OcrRecognizeDocumentRead:
input_path = str(payload.get("input_path") or "")
filename, media_type = meta_by_path.get(
input_path,
(Path(input_path).name or "upload.bin", "application/octet-stream"),
)
lines = [
OcrRecognizeLineRead(
text=str(item.get("text", "")),
score=float(item.get("score", 0.0) or 0.0),
box=[
[int(point[0]), int(point[1])]
for point in item.get("box", [])
if isinstance(point, list) and len(point) == 2
],
page_index=int(item["page_index"]) if item.get("page_index") is not None else None,
def _prepare_pdf_inputs(
self,
*,
pdf_path: Path,
filename: str,
media_type: str,
cleanup_paths: list[Path],
) -> list[PreparedOcrInput]:
output_dir = pdf_path.with_suffix("")
output_dir.mkdir(parents=True, exist_ok=True)
cleanup_paths.append(output_dir)
image_paths = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=output_dir)
if not image_paths:
raise RuntimeError("PDF 转图片后未生成可识别页面。")
preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png")
source_key = uuid4().hex
descriptors: list[PreparedOcrInput] = []
for page_index, image_path in enumerate(image_paths):
descriptors.append(
PreparedOcrInput(
input_path=image_path,
source_key=source_key,
filename=filename,
media_type=media_type,
page_index=page_index,
preview_kind="image" if page_index == 0 else "",
preview_data_url=preview_data_url if page_index == 0 else "",
)
)
for item in payload.get("lines", [])
if isinstance(item, dict)
]
return OcrRecognizeDocumentRead(
filename=filename,
media_type=media_type,
engine=str(payload.get("engine", "paddleocr_mobile")),
model=str(payload.get("model", "PP-OCRv5_mobile")),
text=str(payload.get("text", "")),
summary=str(payload.get("summary", "")),
avg_score=float(payload.get("avg_score", 0.0) or 0.0),
line_count=int(payload.get("line_count", len(lines)) or 0),
page_count=int(payload.get("page_count", 1) or 1),
warnings=[str(item) for item in payload.get("warnings", [])],
lines=lines,
return descriptors
def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]:
prefix = output_dir / "page"
completed = subprocess.run(
[
"pdftoppm",
"-png",
"-r",
"160",
str(pdf_path),
str(prefix),
],
capture_output=True,
text=True,
timeout=self.settings.ocr_timeout_seconds,
check=False,
)
if completed.returncode != 0:
detail = (completed.stderr or completed.stdout or "").strip()
raise RuntimeError(f"PDF 转图片失败:{detail or 'pdftoppm 返回非 0 状态码。'}")
return sorted(output_dir.glob("page-*.png"), key=self._extract_pdf_page_sort_key)
@staticmethod
def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]:
suffix = path.stem.rsplit("-", 1)[-1]
try:
return int(suffix), path.name
except ValueError:
return 0, path.name
@staticmethod
def _build_preview_data_url(path: Path, *, media_type: str) -> str:
encoded = base64.b64encode(path.read_bytes()).decode("ascii")
return f"data:{media_type};base64,{encoded}"
def _build_documents(
self,
*,
worker_documents: list[dict],
prepared_inputs: list[PreparedOcrInput],
) -> list[OcrRecognizeDocumentRead]:
descriptor_by_path = {str(item.input_path): item for item in prepared_inputs}
source_order: list[str] = []
seen_sources: set[str] = set()
for item in prepared_inputs:
if item.source_key in seen_sources:
continue
seen_sources.add(item.source_key)
source_order.append(item.source_key)
aggregated_by_source: dict[str, AggregatedOcrDocument] = {}
for payload in worker_documents:
if not isinstance(payload, dict):
continue
input_path = str(payload.get("input_path") or "")
descriptor = descriptor_by_path.get(input_path)
if descriptor is None:
continue
aggregated = aggregated_by_source.get(descriptor.source_key)
if aggregated is None:
aggregated = AggregatedOcrDocument(
filename=descriptor.filename,
media_type=descriptor.media_type,
source_key=descriptor.source_key,
engine=str(payload.get("engine", "paddleocr_mobile")),
model=str(payload.get("model", "PP-OCRv5_mobile")),
)
aggregated_by_source[descriptor.source_key] = aggregated
aggregated.page_count = max(
aggregated.page_count,
(descriptor.page_index + 1)
if descriptor.page_index is not None
else int(payload.get("page_count", 1) or 1),
)
if descriptor.preview_kind and not aggregated.preview_kind:
aggregated.preview_kind = descriptor.preview_kind
if descriptor.preview_data_url and not aggregated.preview_data_url:
aggregated.preview_data_url = descriptor.preview_data_url
page_summary = str(payload.get("summary", "") or "").strip()
if page_summary:
aggregated.summary_fragments.append(page_summary)
page_text = str(payload.get("text", "") or "").strip()
if page_text:
aggregated.text_fragments.append(page_text)
lines = self._build_lines(
payload.get("lines", []),
page_index_override=descriptor.page_index,
)
aggregated.lines.extend(lines)
aggregated.score_values.extend(line.score for line in lines if line.score > 0)
if not lines:
avg_score = float(payload.get("avg_score", 0.0) or 0.0)
if avg_score > 0:
aggregated.score_values.append(avg_score)
for warning in payload.get("warnings", []):
normalized_warning = str(warning or "").strip()
if normalized_warning and normalized_warning not in aggregated.warnings:
aggregated.warnings.append(normalized_warning)
documents: list[OcrRecognizeDocumentRead] = []
for source_key in source_order:
descriptors = [item for item in prepared_inputs if item.source_key == source_key]
if not descriptors:
continue
aggregated = aggregated_by_source.get(source_key)
if aggregated is None:
first_descriptor = descriptors[0]
documents.append(
OcrRecognizeDocumentRead(
filename=first_descriptor.filename,
media_type=first_descriptor.media_type,
page_count=max(1, len(descriptors)),
preview_kind=first_descriptor.preview_kind,
preview_data_url=first_descriptor.preview_data_url,
warnings=["OCR worker 未返回该文件的识别结果。"],
)
)
continue
documents.append(self._finalize_document(aggregated))
return documents
@staticmethod
def _build_lines(
items: list[dict],
*,
page_index_override: int | None = None,
) -> list[OcrRecognizeLineRead]:
lines: list[OcrRecognizeLineRead] = []
for item in items:
if not isinstance(item, dict):
continue
page_index = page_index_override
if page_index is None and item.get("page_index") is not None:
page_index = int(item["page_index"])
lines.append(
OcrRecognizeLineRead(
text=str(item.get("text", "")),
score=float(item.get("score", 0.0) or 0.0),
box=[
[int(point[0]), int(point[1])]
for point in item.get("box", [])
if isinstance(point, list) and len(point) == 2
],
page_index=page_index,
)
)
return lines
@staticmethod
def _truncate_summary(parts: list[str]) -> str:
summary = "".join([part for part in parts if part][:3])
if len(summary) > 180:
return f"{summary[:177]}..."
return summary
def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead:
full_text = "\n".join(fragment for fragment in aggregated.text_fragments if fragment).strip()
summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments)
insight = self.document_intelligence_service.build_document_insight(
filename=aggregated.filename,
summary=summary,
text=full_text,
preview_data_url=aggregated.preview_data_url,
)
warnings = list(aggregated.warnings)
for warning in insight.warnings:
normalized_warning = str(warning or "").strip()
if normalized_warning and normalized_warning not in warnings:
warnings.append(normalized_warning)
return OcrRecognizeDocumentRead(
filename=aggregated.filename,
media_type=aggregated.media_type,
engine=aggregated.engine,
model=aggregated.model,
text=full_text,
summary=summary,
avg_score=(
sum(aggregated.score_values) / len(aggregated.score_values)
if aggregated.score_values
else 0.0
),
line_count=len(aggregated.lines),
page_count=max(1, aggregated.page_count),
document_type=insight.document_type,
document_type_label=insight.document_type_label,
scene_code=insight.scene_code,
scene_label=insight.scene_label,
classification_source=insight.classification_source,
classification_confidence=insight.classification_confidence,
classification_evidence=list(insight.evidence),
document_fields=[
OcrRecognizeFieldRead(
key=field.key,
label=field.label,
value=field.value,
)
for field in insight.fields
],
preview_kind=aggregated.preview_kind,
preview_data_url=aggregated.preview_data_url,
warnings=warnings,
lines=sorted(
aggregated.lines,
key=lambda item: item.page_index if item.page_index is not None else -1,
),
)
@staticmethod
def _cleanup_temp_paths(paths: list[Path]) -> None:
for path in reversed(paths):
if path.is_dir():
shutil.rmtree(path, ignore_errors=True)
continue
path.unlink(missing_ok=True)