67 lines
2.6 KiB
Python
67 lines
2.6 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.services.document_intelligence import DocumentIntelligenceService, build_document_insight
|
||
from app.services.runtime_chat import RuntimeChatService
|
||
|
||
|
||
def test_build_document_insight_prefers_transport_for_didi_text_with_hotel_noise() -> None:
    """Didi ride text stays a taxi receipt even when OCR also picked up a hotel name.

    The OCR text mixes Didi trip fields (order number, pickup/drop-off points,
    mileage, amount) with the hotel brand "全季酒店". The insight must still be
    classified as a taxi/ride-hailing receipt in the transport scene, and it
    must expose the fare as an amount field.
    """
    ocr_text = "滴滴出行电子发票 订单号 12345678 上车点 深圳湾 下车点 后海 全季酒店 里程 12.4 公里 金额 48 元"
    result = build_document_insight(
        filename="didi-trip.png",
        summary="滴滴出行行程单",
        text=ocr_text,
    )

    assert result.document_type == "taxi_receipt"
    assert result.document_type_label == "出租车/网约车票据"
    assert result.scene_code == "transport"

    # The raw "48 元" in the OCR text is expected to surface as the value "48元".
    amounts = [field.value for field in result.fields if field.label == "金额"]
    assert "48元" in amounts
|
||
|
||
|
||
def test_document_intelligence_service_uses_vlm_result_when_preview_available(monkeypatch) -> None:
    """A preview image should route classification through the vision ("vlm") slot first.

    ``RuntimeChatService.complete`` is patched so that only the ``("vlm",)``
    slot returns a classification payload; every other slot returns ``None``.
    The OCR text deliberately mixes a hotel name with Didi ride fields. The
    test checks that the first completion request targets the vision slot and
    that the resulting insight is the vision model's taxi-receipt
    classification.
    """
    calls: list[tuple[str, ...]] = []

    def fake_complete(
        self,
        messages,
        *,
        slot_priority=("main", "backup"),
        max_tokens=500,
        temperature=0.2,
    ):
        # Normalise to a tuple so a list-valued slot_priority from the caller
        # still matches the ("vlm",) check and the type recorded in `calls`.
        slots = tuple(slot_priority)
        calls.append(slots)
        if slots == ("vlm",):
            # The vision request must carry multimodal (list-valued) content.
            assert isinstance(messages[1]["content"], list)
            return json.dumps(
                {
                    "document_type": "taxi_receipt",
                    "scene_code": "transport",
                    "scene_label": "交通票据",
                    "expense_type": "transport",
                    "confidence": 0.91,
                    "evidence": ["图片主体为滴滴行程单,OCR 中出现订单号、上车、下车等字段"],
                },
                ensure_ascii=False,
            )
        # Non-vision slots return no completion.
        return None

    monkeypatch.setattr(RuntimeChatService, "complete", fake_complete)

    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    session = sessionmaker(bind=engine, autoflush=False, autocommit=False)()
    try:
        insight = DocumentIntelligenceService(session).build_document_insight(
            filename="mixed-noise.png",
            summary="OCR 混入酒店名称",
            text="全季酒店 滴滴出行 订单号 12345678 上车 下车 金额 52 元",
            preview_data_url="data:image/png;base64,ZmFrZQ==",
        )
    finally:
        session.close()
        # Dispose the engine so its pooled in-memory SQLite connection is closed
        # instead of lingering after the test.
        engine.dispose()

    assert insight.document_type == "taxi_receipt"
    assert insight.classification_source == "llm_vision"
    # Fail with a clear message (not an IndexError) if no completion was requested.
    assert calls, "RuntimeChatService.complete was never called"
    assert calls[0] == ("vlm",)
|