feat(server): 扩展文档智能识别服务,新增Azure Document Intelligence集成和测试用例

This commit is contained in:
caoxiaozhu
2026-05-14 15:42:29 +00:00
parent e21f0d82e9
commit c99a423f6a
2 changed files with 212 additions and 10 deletions

View File

@@ -59,6 +59,7 @@ class LlmDocumentClassification(BaseModel):
expense_type: str = Field(default="other")
confidence: float = Field(default=0.0, ge=0.0, le=1.0)
evidence: list[str] = Field(default_factory=list)
fields: list[DocumentField] = Field(default_factory=list)
DEFAULT_RULE = DocumentRule(
@@ -177,7 +178,11 @@ DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES}
SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",)
AMOUNT_PATTERNS = (
re.compile(r"(?:价税合计|合计|金额|总额|票价|支付金额|实付金额|实收金额)[:\s¥¥]*([0-9]+(?:[.,][0-9]{1,2})?)"),
re.compile(
r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
r"[:\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
),
re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"),
re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
@@ -278,7 +283,7 @@ class DocumentIntelligenceService:
system_prompt = (
"你是企业报销票据识别复核器。"
"你的任务不是 OCR而是在已有 OCR 文本和票据预览基础上判断票据类型。"
"你的任务不是 OCR而是在已有 OCR 文本和票据预览基础上判断票据类型,并尽量复核关键字段"
"只输出 JSON 对象,不要输出 Markdown、解释或代码块。"
"document_type 只能是:"
f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}"
@@ -286,7 +291,10 @@ class DocumentIntelligenceService:
"严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。"
"如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。"
"例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。"
"输出字段document_type, scene_code, scene_label, expense_type, confidence, evidence"
"如果能从 OCR 或图片中明确确认字段,可在 fields 中返回"
"fields 只允许包含 key, label, valuekey 只能是 amount, date, merchant_name, invoice_number, "
"invoice_code, trip_no, route。无法确认就不要返回该字段。"
"输出字段document_type, scene_code, scene_label, expense_type, confidence, evidence, fields。"
)
user_prompt = (
"请根据以下票据事实给出最终分类 JSON\n"
@@ -298,7 +306,8 @@ class DocumentIntelligenceService:
' "scene_label": "交通票据",\n'
' "expense_type": "transport",\n'
' "confidence": 0.86,\n'
' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"]\n'
' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"],\n'
' "fields": [{"key": "amount", "label": "金额", "value": "32.5"}]\n'
"}"
)
@@ -357,6 +366,7 @@ class DocumentIntelligenceService:
for item in parsed.evidence
if str(item or "").strip()
][:4]
normalized_fields = _normalize_llm_document_fields(parsed.fields)
return LlmDocumentClassification(
document_type=normalized_type,
@@ -365,6 +375,7 @@ class DocumentIntelligenceService:
expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type,
confidence=float(parsed.confidence),
evidence=evidence,
fields=normalized_fields,
)
@staticmethod
@@ -376,8 +387,28 @@ class DocumentIntelligenceService:
has_preview: bool,
) -> DocumentInsight:
source, parsed = llm_result
warnings = list(rule_insight.warnings)
merged_fields = rule_insight.fields
if parsed.fields and (has_preview or parsed.confidence >= 0.55):
merged_fields = _merge_document_fields(rule_insight.fields, tuple(parsed.fields))
if merged_fields != rule_insight.fields:
warnings.append("票据关键信息已结合大模型复核结果修正,建议人工再核对原图。")
if parsed.confidence < 0.55:
if merged_fields == rule_insight.fields:
return rule_insight
return DocumentInsight(
document_type=rule_insight.document_type,
document_type_label=rule_insight.document_type_label,
scene_code=rule_insight.scene_code,
scene_label=rule_insight.scene_label,
expense_type=rule_insight.expense_type,
fields=merged_fields,
classification_source=rule_insight.classification_source,
classification_confidence=rule_insight.classification_confidence,
evidence=rule_insight.evidence,
warnings=tuple(warnings),
)
should_override = False
if parsed.document_type == rule_insight.document_type:
@@ -389,10 +420,22 @@ class DocumentIntelligenceService:
should_override = parsed.confidence >= threshold
if not should_override:
if merged_fields == rule_insight.fields:
return rule_insight
return DocumentInsight(
document_type=rule_insight.document_type,
document_type_label=rule_insight.document_type_label,
scene_code=rule_insight.scene_code,
scene_label=rule_insight.scene_label,
expense_type=rule_insight.expense_type,
fields=merged_fields,
classification_source=rule_insight.classification_source,
classification_confidence=rule_insight.classification_confidence,
evidence=rule_insight.evidence,
warnings=tuple(warnings),
)
rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
warnings = list(rule_insight.warnings)
if parsed.document_type != rule_insight.document_type:
warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
@@ -402,7 +445,7 @@ class DocumentIntelligenceService:
scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
fields=fields,
fields=merged_fields,
classification_source=source,
classification_confidence=max(parsed.confidence, rule_insight.classification_confidence),
evidence=tuple(parsed.evidence or rule_insight.evidence),
@@ -479,6 +522,115 @@ def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
return None
def _normalize_llm_document_fields(raw_fields: list[DocumentField] | list[dict[str, Any]]) -> list[DocumentField]:
normalized: list[DocumentField] = []
seen_keys: set[str] = set()
for field in raw_fields:
raw_key = str(getattr(field, "key", "") if isinstance(field, DocumentField) else field.get("key") or "").strip()
raw_label = str(getattr(field, "label", "") if isinstance(field, DocumentField) else field.get("label") or "").strip()
raw_value = str(getattr(field, "value", "") if isinstance(field, DocumentField) else field.get("value") or "").strip()
key = _normalize_llm_document_field_key(raw_key, raw_label)
if not key or key in seen_keys:
continue
value = _normalize_llm_document_field_value(key, raw_value)
if not value:
continue
seen_keys.add(key)
normalized.append(
DocumentField(
key=key,
label=_llm_document_field_label(key),
value=value,
)
)
return normalized
def _normalize_llm_document_field_key(key: str, label: str) -> str:
compact_key = str(key or "").strip().lower()
compact_label = str(label or "").replace(" ", "").lower()
if compact_key in {"amount", "total_amount", "payment_amount", "paid_amount"} or any(
token in compact_label for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
):
return "amount"
if compact_key in {"date", "time", "issued_at", "invoice_date"} or any(
token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")
):
return "date"
if compact_key in {"merchant_name", "merchant", "seller_name", "vendor_name"} or any(
token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方")
):
return "merchant_name"
if compact_key in {"invoice_number", "ticket_number", "order_no", "order_number"} or any(
token in compact_label for token in ("票据号码", "发票号码", "票号", "单号", "订单号")
):
return "invoice_number"
if compact_key in {"invoice_code"} or "发票代码" in compact_label:
return "invoice_code"
if compact_key in {"trip_no", "flight_no", "train_no"} or any(
token in compact_label for token in ("车次", "航班")
):
return "trip_no"
if compact_key in {"route", "trip_route"} or any(token in compact_label for token in ("行程", "路线")):
return "route"
return ""
def _normalize_llm_document_field_value(key: str, value: str) -> str:
raw_value = str(value or "").strip()
if not raw_value:
return ""
if key == "amount":
amount = _extract_amount(raw_value)
if amount:
return amount
cleaned = raw_value.replace("", "").replace("¥", "").replace("", "").replace(",", ".").strip()
try:
candidate = Decimal(cleaned)
except InvalidOperation:
return ""
if candidate <= Decimal("0.00"):
return ""
text_value = format(candidate.quantize(Decimal("0.01")), "f").rstrip("0").rstrip(".")
return f"{text_value}"
if key == "date":
return _extract_date(raw_value) or _clean_field_value(raw_value)
if key == "route":
return _extract_route(raw_value) or _clean_field_value(
raw_value.replace("", "-").replace("", "-").replace("->", "-")
)
return _clean_field_value(raw_value)
def _llm_document_field_label(key: str) -> str:
return {
"amount": "金额",
"date": "日期",
"merchant_name": "商户",
"invoice_number": "票据号码",
"invoice_code": "发票代码",
"trip_no": "车次/航班",
"route": "行程",
}.get(key, key)
def _merge_document_fields(
base_fields: tuple[DocumentField, ...],
override_fields: tuple[DocumentField, ...],
) -> tuple[DocumentField, ...]:
merged = {field.key: field for field in base_fields if field.key and field.value}
order = [field.key for field in base_fields if field.key and field.value]
for field in override_fields:
if not field.key or not field.value:
continue
merged[field.key] = field
if field.key not in order:
order.append(field.key)
return tuple(merged[key] for key in order if key in merged)
def _extract_document_fields(text: str) -> list[DocumentField]:
fields: list[DocumentField] = []
amount = _extract_amount(text)
@@ -525,8 +677,6 @@ def _extract_amount(text: str) -> str:
continue
if best_value is None or candidate > best_value:
best_value = candidate
if best_value is not None:
break
if best_value is None:
return ""

View File

@@ -64,3 +64,55 @@ def test_document_intelligence_service_uses_vlm_result_when_preview_available(mo
assert insight.document_type == "taxi_receipt"
assert insight.classification_source == "llm_vision"
assert calls[0] == ("vlm",)
def test_document_intelligence_extracts_larger_decimal_amount_from_multiple_candidates() -> None:
insight = build_document_insight(
filename="taxi-amount.png",
summary="滴滴出行电子行程单",
text="滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678",
)
assert any(field.label == "金额" and field.value == "13.4元" for field in insight.fields)
def test_document_intelligence_service_uses_vlm_fields_to_correct_amount(monkeypatch) -> None:
def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2):
if slot_priority == ("vlm",):
return json.dumps(
{
"document_type": "taxi_receipt",
"scene_code": "transport",
"scene_label": "交通票据",
"expense_type": "transport",
"confidence": 0.89,
"evidence": ["图片主体为滴滴行程单,金额区域显示 13.4 元"],
"fields": [
{"key": "amount", "label": "金额", "value": "13.4"},
{"key": "merchant_name", "label": "商户", "value": "滴滴出行"},
],
},
ensure_ascii=False,
)
return None
monkeypatch.setattr(RuntimeChatService, "complete", fake_complete)
engine = create_engine(
"sqlite+pysqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
session = sessionmaker(bind=engine, autoflush=False, autocommit=False)()
try:
insight = DocumentIntelligenceService(session).build_document_insight(
filename="didi-corrected.png",
summary="滴滴出行电子行程单",
text="滴滴出行 支付金额 1 元 订单号 12345678",
preview_data_url="data:image/png;base64,ZmFrZQ==",
)
finally:
session.close()
assert any(field.label == "金额" and field.value == "13.4元" for field in insight.fields)
assert any("大模型复核结果修正" in warning for warning in insight.warnings)