feat(server): 新增文档智能识别服务,扩展OCR接口支持 Azure Document Intelligence
This commit is contained in:
@@ -1,21 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import SERVER_DIR, get_settings
|
||||
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
|
||||
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
|
||||
from app.services.document_intelligence import DocumentIntelligenceService
|
||||
|
||||
WORKER_JSON_PREFIX = "__OCR_JSON__="
|
||||
SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PreparedOcrInput:
|
||||
input_path: Path
|
||||
source_key: str
|
||||
filename: str
|
||||
media_type: str
|
||||
page_index: int | None = None
|
||||
preview_kind: str = ""
|
||||
preview_data_url: str = ""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AggregatedOcrDocument:
|
||||
filename: str
|
||||
media_type: str
|
||||
source_key: str
|
||||
engine: str = "paddleocr_mobile"
|
||||
model: str = "PP-OCRv5_mobile"
|
||||
summary_fragments: list[str] = field(default_factory=list)
|
||||
text_fragments: list[str] = field(default_factory=list)
|
||||
score_values: list[float] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
lines: list[OcrRecognizeLineRead] = field(default_factory=list)
|
||||
page_count: int = 0
|
||||
preview_kind: str = ""
|
||||
preview_data_url: str = ""
|
||||
|
||||
|
||||
class OcrService:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, db: Session | None = None) -> None:
|
||||
self.settings = get_settings()
|
||||
self.document_intelligence_service = DocumentIntelligenceService(db)
|
||||
|
||||
def recognize_files(
|
||||
self,
|
||||
@@ -28,10 +62,11 @@ class OcrService:
|
||||
temp_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
documents: list[OcrRecognizeDocumentRead] = []
|
||||
input_paths: list[Path] = []
|
||||
meta_by_path: dict[str, tuple[str, str]] = {}
|
||||
prepared_inputs: list[PreparedOcrInput] = []
|
||||
cleanup_paths: list[Path] = []
|
||||
python_bin = self._resolve_python_bin()
|
||||
worker_path = self._resolve_worker_path()
|
||||
worker_payload: dict = {}
|
||||
|
||||
try:
|
||||
for filename, content, media_type in files:
|
||||
@@ -73,17 +108,55 @@ class OcrService:
|
||||
|
||||
temp_path = temp_root / f"{uuid4().hex}{suffix}"
|
||||
temp_path.write_bytes(content)
|
||||
input_paths.append(temp_path)
|
||||
meta_by_path[str(temp_path)] = (normalized_name, resolved_media_type)
|
||||
cleanup_paths.append(temp_path)
|
||||
|
||||
if input_paths:
|
||||
if suffix == ".pdf":
|
||||
try:
|
||||
prepared_inputs.extend(
|
||||
self._prepare_pdf_inputs(
|
||||
pdf_path=temp_path,
|
||||
filename=normalized_name,
|
||||
media_type=resolved_media_type,
|
||||
cleanup_paths=cleanup_paths,
|
||||
)
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
documents.append(
|
||||
OcrRecognizeDocumentRead(
|
||||
filename=normalized_name,
|
||||
media_type=resolved_media_type,
|
||||
warnings=[str(exc)],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
prepared_inputs.append(
|
||||
PreparedOcrInput(
|
||||
input_path=temp_path,
|
||||
source_key=uuid4().hex,
|
||||
filename=normalized_name,
|
||||
media_type=resolved_media_type,
|
||||
preview_kind="image" if resolved_media_type.startswith("image/") else "",
|
||||
preview_data_url=(
|
||||
self._build_preview_data_url(temp_path, media_type=resolved_media_type)
|
||||
if resolved_media_type.startswith("image/")
|
||||
else ""
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if prepared_inputs:
|
||||
worker_payload = self._invoke_worker(
|
||||
python_bin=python_bin,
|
||||
worker_path=worker_path,
|
||||
input_paths=input_paths,
|
||||
input_paths=[item.input_path for item in prepared_inputs],
|
||||
)
|
||||
documents.extend(
|
||||
self._build_documents(
|
||||
worker_documents=worker_payload.get("documents", []),
|
||||
prepared_inputs=prepared_inputs,
|
||||
)
|
||||
)
|
||||
for item in worker_payload.get("documents", []):
|
||||
documents.append(self._build_document(item, meta_by_path))
|
||||
|
||||
success_count = sum(
|
||||
1
|
||||
@@ -92,12 +165,12 @@ class OcrService:
|
||||
)
|
||||
engine = (
|
||||
str(worker_payload.get("engine", "paddleocr_mobile"))
|
||||
if input_paths
|
||||
if prepared_inputs
|
||||
else "paddleocr_mobile"
|
||||
)
|
||||
model = (
|
||||
str(worker_payload.get("model", "PP-OCRv5_mobile"))
|
||||
if input_paths
|
||||
if prepared_inputs
|
||||
else "PP-OCRv5_mobile"
|
||||
)
|
||||
return OcrRecognizeBatchRead(
|
||||
@@ -108,8 +181,7 @@ class OcrService:
|
||||
documents=documents,
|
||||
)
|
||||
finally:
|
||||
for path in input_paths:
|
||||
path.unlink(missing_ok=True)
|
||||
self._cleanup_temp_paths(cleanup_paths)
|
||||
|
||||
def _resolve_python_bin(self) -> str:
|
||||
candidates = []
|
||||
@@ -182,40 +254,258 @@ class OcrService:
|
||||
return json.loads(normalized[len(WORKER_JSON_PREFIX) :])
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_document(
|
||||
payload: dict,
|
||||
meta_by_path: dict[str, tuple[str, str]],
|
||||
) -> OcrRecognizeDocumentRead:
|
||||
input_path = str(payload.get("input_path") or "")
|
||||
filename, media_type = meta_by_path.get(
|
||||
input_path,
|
||||
(Path(input_path).name or "upload.bin", "application/octet-stream"),
|
||||
)
|
||||
lines = [
|
||||
OcrRecognizeLineRead(
|
||||
text=str(item.get("text", "")),
|
||||
score=float(item.get("score", 0.0) or 0.0),
|
||||
box=[
|
||||
[int(point[0]), int(point[1])]
|
||||
for point in item.get("box", [])
|
||||
if isinstance(point, list) and len(point) == 2
|
||||
],
|
||||
page_index=int(item["page_index"]) if item.get("page_index") is not None else None,
|
||||
def _prepare_pdf_inputs(
|
||||
self,
|
||||
*,
|
||||
pdf_path: Path,
|
||||
filename: str,
|
||||
media_type: str,
|
||||
cleanup_paths: list[Path],
|
||||
) -> list[PreparedOcrInput]:
|
||||
output_dir = pdf_path.with_suffix("")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
cleanup_paths.append(output_dir)
|
||||
|
||||
image_paths = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=output_dir)
|
||||
if not image_paths:
|
||||
raise RuntimeError("PDF 转图片后未生成可识别页面。")
|
||||
|
||||
preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png")
|
||||
source_key = uuid4().hex
|
||||
descriptors: list[PreparedOcrInput] = []
|
||||
for page_index, image_path in enumerate(image_paths):
|
||||
descriptors.append(
|
||||
PreparedOcrInput(
|
||||
input_path=image_path,
|
||||
source_key=source_key,
|
||||
filename=filename,
|
||||
media_type=media_type,
|
||||
page_index=page_index,
|
||||
preview_kind="image" if page_index == 0 else "",
|
||||
preview_data_url=preview_data_url if page_index == 0 else "",
|
||||
)
|
||||
)
|
||||
for item in payload.get("lines", [])
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
return OcrRecognizeDocumentRead(
|
||||
filename=filename,
|
||||
media_type=media_type,
|
||||
engine=str(payload.get("engine", "paddleocr_mobile")),
|
||||
model=str(payload.get("model", "PP-OCRv5_mobile")),
|
||||
text=str(payload.get("text", "")),
|
||||
summary=str(payload.get("summary", "")),
|
||||
avg_score=float(payload.get("avg_score", 0.0) or 0.0),
|
||||
line_count=int(payload.get("line_count", len(lines)) or 0),
|
||||
page_count=int(payload.get("page_count", 1) or 1),
|
||||
warnings=[str(item) for item in payload.get("warnings", [])],
|
||||
lines=lines,
|
||||
return descriptors
|
||||
|
||||
def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]:
|
||||
prefix = output_dir / "page"
|
||||
completed = subprocess.run(
|
||||
[
|
||||
"pdftoppm",
|
||||
"-png",
|
||||
"-r",
|
||||
"160",
|
||||
str(pdf_path),
|
||||
str(prefix),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=self.settings.ocr_timeout_seconds,
|
||||
check=False,
|
||||
)
|
||||
if completed.returncode != 0:
|
||||
detail = (completed.stderr or completed.stdout or "").strip()
|
||||
raise RuntimeError(f"PDF 转图片失败:{detail or 'pdftoppm 返回非 0 状态码。'}")
|
||||
|
||||
return sorted(output_dir.glob("page-*.png"), key=self._extract_pdf_page_sort_key)
|
||||
|
||||
@staticmethod
|
||||
def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]:
|
||||
suffix = path.stem.rsplit("-", 1)[-1]
|
||||
try:
|
||||
return int(suffix), path.name
|
||||
except ValueError:
|
||||
return 0, path.name
|
||||
|
||||
@staticmethod
|
||||
def _build_preview_data_url(path: Path, *, media_type: str) -> str:
|
||||
encoded = base64.b64encode(path.read_bytes()).decode("ascii")
|
||||
return f"data:{media_type};base64,{encoded}"
|
||||
|
||||
def _build_documents(
|
||||
self,
|
||||
*,
|
||||
worker_documents: list[dict],
|
||||
prepared_inputs: list[PreparedOcrInput],
|
||||
) -> list[OcrRecognizeDocumentRead]:
|
||||
descriptor_by_path = {str(item.input_path): item for item in prepared_inputs}
|
||||
source_order: list[str] = []
|
||||
seen_sources: set[str] = set()
|
||||
for item in prepared_inputs:
|
||||
if item.source_key in seen_sources:
|
||||
continue
|
||||
seen_sources.add(item.source_key)
|
||||
source_order.append(item.source_key)
|
||||
|
||||
aggregated_by_source: dict[str, AggregatedOcrDocument] = {}
|
||||
for payload in worker_documents:
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
input_path = str(payload.get("input_path") or "")
|
||||
descriptor = descriptor_by_path.get(input_path)
|
||||
if descriptor is None:
|
||||
continue
|
||||
|
||||
aggregated = aggregated_by_source.get(descriptor.source_key)
|
||||
if aggregated is None:
|
||||
aggregated = AggregatedOcrDocument(
|
||||
filename=descriptor.filename,
|
||||
media_type=descriptor.media_type,
|
||||
source_key=descriptor.source_key,
|
||||
engine=str(payload.get("engine", "paddleocr_mobile")),
|
||||
model=str(payload.get("model", "PP-OCRv5_mobile")),
|
||||
)
|
||||
aggregated_by_source[descriptor.source_key] = aggregated
|
||||
|
||||
aggregated.page_count = max(
|
||||
aggregated.page_count,
|
||||
(descriptor.page_index + 1)
|
||||
if descriptor.page_index is not None
|
||||
else int(payload.get("page_count", 1) or 1),
|
||||
)
|
||||
if descriptor.preview_kind and not aggregated.preview_kind:
|
||||
aggregated.preview_kind = descriptor.preview_kind
|
||||
if descriptor.preview_data_url and not aggregated.preview_data_url:
|
||||
aggregated.preview_data_url = descriptor.preview_data_url
|
||||
|
||||
page_summary = str(payload.get("summary", "") or "").strip()
|
||||
if page_summary:
|
||||
aggregated.summary_fragments.append(page_summary)
|
||||
|
||||
page_text = str(payload.get("text", "") or "").strip()
|
||||
if page_text:
|
||||
aggregated.text_fragments.append(page_text)
|
||||
|
||||
lines = self._build_lines(
|
||||
payload.get("lines", []),
|
||||
page_index_override=descriptor.page_index,
|
||||
)
|
||||
aggregated.lines.extend(lines)
|
||||
aggregated.score_values.extend(line.score for line in lines if line.score > 0)
|
||||
|
||||
if not lines:
|
||||
avg_score = float(payload.get("avg_score", 0.0) or 0.0)
|
||||
if avg_score > 0:
|
||||
aggregated.score_values.append(avg_score)
|
||||
|
||||
for warning in payload.get("warnings", []):
|
||||
normalized_warning = str(warning or "").strip()
|
||||
if normalized_warning and normalized_warning not in aggregated.warnings:
|
||||
aggregated.warnings.append(normalized_warning)
|
||||
|
||||
documents: list[OcrRecognizeDocumentRead] = []
|
||||
for source_key in source_order:
|
||||
descriptors = [item for item in prepared_inputs if item.source_key == source_key]
|
||||
if not descriptors:
|
||||
continue
|
||||
aggregated = aggregated_by_source.get(source_key)
|
||||
if aggregated is None:
|
||||
first_descriptor = descriptors[0]
|
||||
documents.append(
|
||||
OcrRecognizeDocumentRead(
|
||||
filename=first_descriptor.filename,
|
||||
media_type=first_descriptor.media_type,
|
||||
page_count=max(1, len(descriptors)),
|
||||
preview_kind=first_descriptor.preview_kind,
|
||||
preview_data_url=first_descriptor.preview_data_url,
|
||||
warnings=["OCR worker 未返回该文件的识别结果。"],
|
||||
)
|
||||
)
|
||||
continue
|
||||
documents.append(self._finalize_document(aggregated))
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def _build_lines(
|
||||
items: list[dict],
|
||||
*,
|
||||
page_index_override: int | None = None,
|
||||
) -> list[OcrRecognizeLineRead]:
|
||||
lines: list[OcrRecognizeLineRead] = []
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
page_index = page_index_override
|
||||
if page_index is None and item.get("page_index") is not None:
|
||||
page_index = int(item["page_index"])
|
||||
lines.append(
|
||||
OcrRecognizeLineRead(
|
||||
text=str(item.get("text", "")),
|
||||
score=float(item.get("score", 0.0) or 0.0),
|
||||
box=[
|
||||
[int(point[0]), int(point[1])]
|
||||
for point in item.get("box", [])
|
||||
if isinstance(point, list) and len(point) == 2
|
||||
],
|
||||
page_index=page_index,
|
||||
)
|
||||
)
|
||||
return lines
|
||||
|
||||
@staticmethod
|
||||
def _truncate_summary(parts: list[str]) -> str:
|
||||
summary = ";".join([part for part in parts if part][:3])
|
||||
if len(summary) > 180:
|
||||
return f"{summary[:177]}..."
|
||||
return summary
|
||||
|
||||
def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead:
|
||||
full_text = "\n".join(fragment for fragment in aggregated.text_fragments if fragment).strip()
|
||||
summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments)
|
||||
insight = self.document_intelligence_service.build_document_insight(
|
||||
filename=aggregated.filename,
|
||||
summary=summary,
|
||||
text=full_text,
|
||||
preview_data_url=aggregated.preview_data_url,
|
||||
)
|
||||
warnings = list(aggregated.warnings)
|
||||
for warning in insight.warnings:
|
||||
normalized_warning = str(warning or "").strip()
|
||||
if normalized_warning and normalized_warning not in warnings:
|
||||
warnings.append(normalized_warning)
|
||||
return OcrRecognizeDocumentRead(
|
||||
filename=aggregated.filename,
|
||||
media_type=aggregated.media_type,
|
||||
engine=aggregated.engine,
|
||||
model=aggregated.model,
|
||||
text=full_text,
|
||||
summary=summary,
|
||||
avg_score=(
|
||||
sum(aggregated.score_values) / len(aggregated.score_values)
|
||||
if aggregated.score_values
|
||||
else 0.0
|
||||
),
|
||||
line_count=len(aggregated.lines),
|
||||
page_count=max(1, aggregated.page_count),
|
||||
document_type=insight.document_type,
|
||||
document_type_label=insight.document_type_label,
|
||||
scene_code=insight.scene_code,
|
||||
scene_label=insight.scene_label,
|
||||
classification_source=insight.classification_source,
|
||||
classification_confidence=insight.classification_confidence,
|
||||
classification_evidence=list(insight.evidence),
|
||||
document_fields=[
|
||||
OcrRecognizeFieldRead(
|
||||
key=field.key,
|
||||
label=field.label,
|
||||
value=field.value,
|
||||
)
|
||||
for field in insight.fields
|
||||
],
|
||||
preview_kind=aggregated.preview_kind,
|
||||
preview_data_url=aggregated.preview_data_url,
|
||||
warnings=warnings,
|
||||
lines=sorted(
|
||||
aggregated.lines,
|
||||
key=lambda item: item.page_index if item.page_index is not None else -1,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_temp_paths(paths: list[Path]) -> None:
|
||||
for path in reversed(paths):
|
||||
if path.is_dir():
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
continue
|
||||
path.unlink(missing_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user