Files
X-Financial/server/tests/test_document_intelligence.py

119 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.services.document_intelligence import DocumentIntelligenceService, build_document_insight
from app.services.runtime_chat import RuntimeChatService
def test_build_document_insight_prefers_transport_for_didi_text_with_hotel_noise() -> None:
    """Ride-hailing text that also mentions a hotel is still classified as transport.

    The input text contains ride details (order number, pick-up, drop-off,
    mileage, amount) together with a hotel name. The test pins the resulting
    document type, its display label, the scene code, and the extracted
    amount field.
    """
    document = {
        "filename": "didi-trip.png",
        "summary": "滴滴出行行程单",
        "text": "滴滴出行电子发票 订单号 12345678 上车点 深圳湾 下车点 后海 全季酒店 里程 12.4 公里 金额 48 元",
    }
    insight = build_document_insight(**document)

    # Classification must follow the ride-hailing content, not the hotel name.
    assert insight.document_type == "taxi_receipt"
    assert insight.document_type_label == "出租车/网约车票据"
    assert insight.scene_code == "transport"

    # At least one extracted field must be the amount with the expected value.
    matching_amounts = [
        entry
        for entry in insight.fields
        if entry.label == "金额" and entry.value == "48元"
    ]
    assert matching_amounts
def test_document_intelligence_service_uses_vlm_result_when_preview_available(monkeypatch) -> None:
    """The service takes its classification from the "vlm" slot when a preview image is given.

    ``RuntimeChatService.complete`` is replaced with a stub that records every
    requested ``slot_priority`` and answers only ``("vlm",)`` with a JSON
    taxi-receipt classification; any other slot gets ``None``. The test then
    checks that the resulting insight carries that classification, reports
    ``"llm_vision"`` as its source, and that the first recorded call targeted
    the ``("vlm",)`` slot.
    """
    calls: list[tuple[str, ...]] = []

    def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2):
        # Record the slot so the test can verify which one was tried first.
        calls.append(slot_priority)
        if slot_priority == ("vlm",):
            # The vision request must carry list-valued content in its second
            # message (presumably multimodal parts -- confirm against the service).
            assert isinstance(messages[1]["content"], list)
            return json.dumps(
                {
                    "document_type": "taxi_receipt",
                    "scene_code": "transport",
                    "scene_label": "交通票据",
                    "expense_type": "transport",
                    "confidence": 0.91,
                    "evidence": ["图片主体为滴滴行程单OCR 中出现订单号、上车、下车等字段"],
                },
                ensure_ascii=False,
            )
        return None

    monkeypatch.setattr(RuntimeChatService, "complete", fake_complete)

    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        session = sessionmaker(bind=engine, autoflush=False, autocommit=False)()
        try:
            insight = DocumentIntelligenceService(session).build_document_insight(
                filename="mixed-noise.png",
                summary="OCR 混入酒店名称",
                text="全季酒店 滴滴出行 订单号 12345678 上车 下车 金额 52 元",
                preview_data_url="data:image/png;base64,ZmFrZQ==",
            )
        finally:
            session.close()
    finally:
        # Release the pooled in-memory connection instead of leaving it to
        # garbage collection.
        engine.dispose()

    assert insight.document_type == "taxi_receipt"
    assert insight.classification_source == "llm_vision"
    # Fail with a readable message rather than an IndexError if the stub was
    # never invoked.
    assert calls, "RuntimeChatService.complete was never called"
    assert calls[0] == ("vlm",)
def test_document_intelligence_extracts_larger_decimal_amount_from_multiple_candidates() -> None:
    """With two amount candidates in the text, the 13.4 value is the one extracted.

    The input text mentions both "1 元" and "13.4 元"; the test requires an
    amount field whose value is "13.4元".
    """
    document = {
        "filename": "taxi-amount.png",
        "summary": "滴滴出行电子行程单",
        "text": "滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678",
    }
    insight = build_document_insight(**document)

    # At least one extracted field must be the amount with the decimal value.
    matching_amounts = [
        entry
        for entry in insight.fields
        if entry.label == "金额" and entry.value == "13.4元"
    ]
    assert matching_amounts
def test_document_intelligence_service_uses_vlm_fields_to_correct_amount(monkeypatch) -> None:
    """Fields returned by the "vlm" slot override the amount found in the text.

    ``RuntimeChatService.complete`` is replaced with a stub that answers only
    the ``("vlm",)`` slot priority, returning a JSON classification whose
    ``fields`` include an amount of ``13.4``; any other slot gets ``None``.
    The input text itself only mentions "1 元". The test requires the final
    insight to expose an amount field of "13.4元" and a warning containing
    "大模型复核结果修正".
    """

    def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2):
        if slot_priority == ("vlm",):
            return json.dumps(
                {
                    "document_type": "taxi_receipt",
                    "scene_code": "transport",
                    "scene_label": "交通票据",
                    "expense_type": "transport",
                    "confidence": 0.89,
                    "evidence": ["图片主体为滴滴行程单,金额区域显示 13.4 元"],
                    "fields": [
                        {"key": "amount", "label": "金额", "value": "13.4"},
                        {"key": "merchant_name", "label": "商户", "value": "滴滴出行"},
                    ],
                },
                ensure_ascii=False,
            )
        return None

    monkeypatch.setattr(RuntimeChatService, "complete", fake_complete)

    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        session = sessionmaker(bind=engine, autoflush=False, autocommit=False)()
        try:
            insight = DocumentIntelligenceService(session).build_document_insight(
                filename="didi-corrected.png",
                summary="滴滴出行电子行程单",
                text="滴滴出行 支付金额 1 元 订单号 12345678",
                preview_data_url="data:image/png;base64,ZmFrZQ==",
            )
        finally:
            session.close()
    finally:
        # Release the pooled in-memory connection instead of leaving it to
        # garbage collection.
        engine.dispose()

    # The amount must come from the stubbed model fields, not the "1 元" in the text.
    assert any(field.label == "金额" and field.value == "13.4元" for field in insight.fields)
    # The correction must be surfaced to the caller as a warning.
    assert any("大模型复核结果修正" in warning for warning in insight.warnings)