feat(server): 新增文档智能识别服务,扩展OCR接口支持 Azure Document Intelligence
This commit is contained in:
66
server/tests/test_document_intelligence.py
Normal file
66
server/tests/test_document_intelligence.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.services.document_intelligence import DocumentIntelligenceService, build_document_insight
|
||||
from app.services.runtime_chat import RuntimeChatService
|
||||
|
||||
|
||||
def test_build_document_insight_prefers_transport_for_didi_text_with_hotel_noise() -> None:
|
||||
insight = build_document_insight(
|
||||
filename="didi-trip.png",
|
||||
summary="滴滴出行行程单",
|
||||
text="滴滴出行电子发票 订单号 12345678 上车点 深圳湾 下车点 后海 全季酒店 里程 12.4 公里 金额 48 元",
|
||||
)
|
||||
|
||||
assert insight.document_type == "taxi_receipt"
|
||||
assert insight.document_type_label == "出租车/网约车票据"
|
||||
assert insight.scene_code == "transport"
|
||||
assert any(field.label == "金额" and field.value == "48元" for field in insight.fields)
|
||||
|
||||
|
||||
def test_document_intelligence_service_uses_vlm_result_when_preview_available(monkeypatch) -> None:
|
||||
calls: list[tuple[str, ...]] = []
|
||||
|
||||
def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2):
|
||||
calls.append(slot_priority)
|
||||
if slot_priority == ("vlm",):
|
||||
assert isinstance(messages[1]["content"], list)
|
||||
return json.dumps(
|
||||
{
|
||||
"document_type": "taxi_receipt",
|
||||
"scene_code": "transport",
|
||||
"scene_label": "交通票据",
|
||||
"expense_type": "transport",
|
||||
"confidence": 0.91,
|
||||
"evidence": ["图片主体为滴滴行程单,OCR 中出现订单号、上车、下车等字段"],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(RuntimeChatService, "complete", fake_complete)
|
||||
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
session = sessionmaker(bind=engine, autoflush=False, autocommit=False)()
|
||||
try:
|
||||
insight = DocumentIntelligenceService(session).build_document_insight(
|
||||
filename="mixed-noise.png",
|
||||
summary="OCR 混入酒店名称",
|
||||
text="全季酒店 滴滴出行 订单号 12345678 上车 下车 金额 52 元",
|
||||
preview_data_url="data:image/png;base64,ZmFrZQ==",
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
assert insight.document_type == "taxi_receipt"
|
||||
assert insight.classification_source == "llm_vision"
|
||||
assert calls[0] == ("vlm",)
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.pool import StaticPool
|
||||
from app.api.deps import get_db
|
||||
from app.db.base import Base
|
||||
from app.main import create_app
|
||||
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
|
||||
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
|
||||
from app.services.ocr import OcrService
|
||||
|
||||
|
||||
@@ -50,14 +50,23 @@ def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None:
|
||||
OcrRecognizeDocumentRead(
|
||||
filename="invoice.png",
|
||||
media_type="image/png",
|
||||
text="发票金额 100 元",
|
||||
summary="发票金额 100 元",
|
||||
text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
|
||||
summary="增值税电子发票,金额 100 元。",
|
||||
avg_score=0.98,
|
||||
line_count=1,
|
||||
page_count=1,
|
||||
document_type="vat_invoice",
|
||||
document_type_label="增值税发票",
|
||||
scene_code="other",
|
||||
scene_label="通用发票",
|
||||
document_fields=[
|
||||
OcrRecognizeFieldRead(key="amount", label="金额", value="100元"),
|
||||
OcrRecognizeFieldRead(key="date", label="日期", value="2026-05-13"),
|
||||
OcrRecognizeFieldRead(key="invoice_number", label="票据号码", value="12345678"),
|
||||
],
|
||||
lines=[
|
||||
OcrRecognizeLineRead(
|
||||
text="发票金额 100 元",
|
||||
text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
|
||||
score=0.98,
|
||||
box=[[1, 2], [10, 2], [10, 8], [1, 8]],
|
||||
page_index=0,
|
||||
@@ -81,4 +90,7 @@ def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None:
|
||||
assert payload["engine"] == "paddleocr_mobile"
|
||||
assert payload["success_count"] == 1
|
||||
assert payload["documents"][0]["filename"] == "invoice.png"
|
||||
assert payload["documents"][0]["summary"] == "发票金额 100 元"
|
||||
assert payload["documents"][0]["summary"] == "增值税电子发票,金额 100 元。"
|
||||
assert payload["documents"][0]["document_type"] == "vat_invoice"
|
||||
assert payload["documents"][0]["document_type_label"] == "增值税发票"
|
||||
assert payload["documents"][0]["document_fields"][0]["label"] == "金额"
|
||||
|
||||
@@ -26,15 +26,15 @@ for index, arg in enumerate(sys.argv):
|
||||
"input_path": input_path,
|
||||
"engine": "paddleocr_mobile",
|
||||
"model": "PP-OCRv5_mobile",
|
||||
"text": "发票金额 100 元",
|
||||
"summary": "发票金额 100 元",
|
||||
"text": "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
|
||||
"summary": "增值税电子发票,金额 100 元。",
|
||||
"avg_score": 0.98,
|
||||
"line_count": 1,
|
||||
"page_count": 1,
|
||||
"warnings": [],
|
||||
"lines": [
|
||||
{
|
||||
"text": "发票金额 100 元",
|
||||
"text": "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
|
||||
"score": 0.98,
|
||||
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
|
||||
"page_index": 0,
|
||||
@@ -74,10 +74,106 @@ print("__OCR_JSON__=" + json.dumps(payload, ensure_ascii=False))
|
||||
assert len(result.documents) == 2
|
||||
|
||||
recognized = next(item for item in result.documents if item.filename == "invoice.png")
|
||||
assert recognized.summary == "发票金额 100 元"
|
||||
assert recognized.summary == "增值税电子发票,金额 100 元。"
|
||||
assert recognized.line_count == 1
|
||||
assert recognized.lines[0].text == "发票金额 100 元"
|
||||
assert recognized.document_type == "vat_invoice"
|
||||
assert recognized.document_type_label == "增值税发票"
|
||||
assert any(field.label == "金额" and field.value == "100元" for field in recognized.document_fields)
|
||||
assert any(field.label == "票据号码" and field.value == "12345678" for field in recognized.document_fields)
|
||||
assert any(field.label == "日期" and field.value == "2026-05-13" for field in recognized.document_fields)
|
||||
assert recognized.lines[0].text == "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13"
|
||||
|
||||
skipped = next(item for item in result.documents if item.filename == "notes.txt")
|
||||
assert skipped.line_count == 0
|
||||
assert skipped.warnings == ["当前仅支持图片和 PDF 文件进行 OCR。"]
|
||||
|
||||
|
||||
def test_ocr_service_converts_pdf_to_images_and_returns_image_preview(
|
||||
monkeypatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
def fake_convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]:
|
||||
first = output_dir / "page-1.png"
|
||||
second = output_dir / "page-2.png"
|
||||
first.write_bytes(b"fake-page-1")
|
||||
second.write_bytes(b"fake-page-2")
|
||||
return [first, second]
|
||||
|
||||
def fake_invoke_worker(
|
||||
self,
|
||||
*,
|
||||
python_bin: str,
|
||||
worker_path: str,
|
||||
input_paths: list[Path],
|
||||
) -> dict:
|
||||
assert [path.name for path in input_paths] == ["page-1.png", "page-2.png"]
|
||||
return {
|
||||
"engine": "paddleocr_mobile",
|
||||
"model": "PP-OCRv5_mobile",
|
||||
"documents": [
|
||||
{
|
||||
"input_path": str(input_paths[0]),
|
||||
"engine": "paddleocr_mobile",
|
||||
"model": "PP-OCRv5_mobile",
|
||||
"text": "高铁票 深圳北-广州南 车次 G1234 2026-05-13 金额 188 元",
|
||||
"summary": "高铁票第一页",
|
||||
"avg_score": 0.97,
|
||||
"line_count": 1,
|
||||
"page_count": 1,
|
||||
"warnings": [],
|
||||
"lines": [
|
||||
{
|
||||
"text": "高铁票 深圳北-广州南 车次 G1234 2026-05-13 金额 188 元",
|
||||
"score": 0.97,
|
||||
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"input_path": str(input_paths[1]),
|
||||
"engine": "paddleocr_mobile",
|
||||
"model": "PP-OCRv5_mobile",
|
||||
"text": "乘车人 张三",
|
||||
"summary": "高铁票第二页",
|
||||
"avg_score": 0.94,
|
||||
"line_count": 1,
|
||||
"page_count": 1,
|
||||
"warnings": [],
|
||||
"lines": [
|
||||
{
|
||||
"text": "乘车人 张三",
|
||||
"score": 0.94,
|
||||
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
|
||||
monkeypatch.setattr(OcrService, "_resolve_python_bin", lambda self: "python")
|
||||
monkeypatch.setattr(OcrService, "_resolve_worker_path", lambda self: "worker.py")
|
||||
monkeypatch.setattr(OcrService, "_convert_pdf_to_images", fake_convert_pdf_to_images)
|
||||
monkeypatch.setattr(OcrService, "_invoke_worker", fake_invoke_worker)
|
||||
get_settings.cache_clear()
|
||||
try:
|
||||
result = OcrService().recognize_files(
|
||||
[
|
||||
("train-ticket.pdf", b"%PDF-1.4 fake", "application/pdf"),
|
||||
]
|
||||
)
|
||||
finally:
|
||||
get_settings.cache_clear()
|
||||
|
||||
assert result.success_count == 1
|
||||
assert len(result.documents) == 1
|
||||
recognized = result.documents[0]
|
||||
assert recognized.filename == "train-ticket.pdf"
|
||||
assert recognized.page_count == 2
|
||||
assert recognized.preview_kind == "image"
|
||||
assert recognized.preview_data_url.startswith("data:image/png;base64,")
|
||||
assert recognized.document_type == "train_ticket"
|
||||
assert any(field.label == "金额" and field.value == "188元" for field in recognized.document_fields)
|
||||
assert any(field.label == "车次/航班" and field.value == "G1234" for field in recognized.document_fields)
|
||||
assert recognized.lines[0].page_index == 0
|
||||
assert recognized.lines[1].page_index == 1
|
||||
|
||||
Reference in New Issue
Block a user