From c99a423f6a92f88b177030ff2fd7b810cea7a31c Mon Sep 17 00:00:00 2001 From: caoxiaozhu Date: Thu, 14 May 2026 15:42:29 +0000 Subject: [PATCH] =?UTF-8?q?feat(server):=20=E6=89=A9=E5=B1=95=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E6=99=BA=E8=83=BD=E8=AF=86=E5=88=AB=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9EAzure=20Document=20Intelligence?= =?UTF-8?q?=E9=9B=86=E6=88=90=E5=92=8C=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/app/services/document_intelligence.py | 170 ++++++++++++++++-- server/tests/test_document_intelligence.py | 52 ++++++ 2 files changed, 212 insertions(+), 10 deletions(-) diff --git a/server/src/app/services/document_intelligence.py b/server/src/app/services/document_intelligence.py index 795c4c4..7a69596 100644 --- a/server/src/app/services/document_intelligence.py +++ b/server/src/app/services/document_intelligence.py @@ -59,6 +59,7 @@ class LlmDocumentClassification(BaseModel): expense_type: str = Field(default="other") confidence: float = Field(default=0.0, ge=0.0, le=1.0) evidence: list[str] = Field(default_factory=list) + fields: list[DocumentField] = Field(default_factory=list) DEFAULT_RULE = DocumentRule( @@ -177,7 +178,11 @@ DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES} SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",) AMOUNT_PATTERNS = ( - re.compile(r"(?:价税合计|合计|金额|总额|票价|支付金额|实付金额|实收金额)[::\s¥¥]*([0-9]+(?:[.,][0-9]{1,2})?)"), + re.compile( + r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)" + r"[::\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)" + ), + re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"), re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"), ) DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)") @@ -278,7 +283,7 @@ class DocumentIntelligenceService: system_prompt = ( "你是企业报销票据识别复核器。" - "你的任务不是 OCR,而是在已有 OCR 文本和票据预览基础上判断票据类型。" + "你的任务不是 OCR,而是在已有 OCR 文本和票据预览基础上判断票据类型,并尽量复核关键字段。" "只输出 JSON 对象,不要输出 Markdown、解释或代码块。" "document_type 只能是:" f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}。" @@ -286,7 +291,10 @@ class DocumentIntelligenceService: "严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。" "如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。" "例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。" - "输出字段:document_type, scene_code, scene_label, expense_type, confidence, evidence。" + "如果能从 OCR 或图片中明确确认字段,可在 fields 中返回。" + "fields 只允许包含 key, label, value,key 只能是 amount, date, merchant_name, invoice_number, " + "invoice_code, trip_no, route。无法确认就不要返回该字段。" + "输出字段:document_type, scene_code, scene_label, expense_type, confidence, evidence, fields。" ) user_prompt = ( "请根据以下票据事实给出最终分类 JSON:\n" @@ -298,7 +306,8 @@ class DocumentIntelligenceService: ' "scene_label": "交通票据",\n' ' "expense_type": "transport",\n' ' "confidence": 0.86,\n' - ' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"]\n' + ' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"],\n' + ' "fields": [{"key": "amount", "label": "金额", "value": "32.5"}]\n' "}" ) @@ -357,6 +366,7 @@ class DocumentIntelligenceService: for item in parsed.evidence if str(item or "").strip() ][:4] + normalized_fields = _normalize_llm_document_fields(parsed.fields) return LlmDocumentClassification( document_type=normalized_type, @@ -365,6 +375,7 @@ class DocumentIntelligenceService: expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type, confidence=float(parsed.confidence), evidence=evidence, + fields=normalized_fields, ) @staticmethod @@ -376,8 +387,28 @@ class DocumentIntelligenceService: has_preview: bool, ) -> DocumentInsight: source, parsed = llm_result + warnings = list(rule_insight.warnings) + merged_fields = rule_insight.fields + if parsed.fields and (has_preview or parsed.confidence >= 0.55): + merged_fields = _merge_document_fields(rule_insight.fields, tuple(parsed.fields)) + if merged_fields != rule_insight.fields: + warnings.append("票据关键信息已结合大模型复核结果修正,建议人工再核对原图。") + if parsed.confidence < 0.55: - return rule_insight + if merged_fields == rule_insight.fields: + return rule_insight + return DocumentInsight( + document_type=rule_insight.document_type, + document_type_label=rule_insight.document_type_label, + scene_code=rule_insight.scene_code, + scene_label=rule_insight.scene_label, + expense_type=rule_insight.expense_type, + fields=merged_fields, + classification_source=rule_insight.classification_source, + classification_confidence=rule_insight.classification_confidence, + evidence=rule_insight.evidence, + warnings=tuple(warnings), + ) should_override = False if parsed.document_type == rule_insight.document_type: @@ -389,10 +420,22 @@ class DocumentIntelligenceService: should_override = parsed.confidence >= threshold if not should_override: - return rule_insight + if merged_fields == rule_insight.fields: + return rule_insight + return DocumentInsight( + document_type=rule_insight.document_type, + document_type_label=rule_insight.document_type_label, + scene_code=rule_insight.scene_code, + scene_label=rule_insight.scene_label, + expense_type=rule_insight.expense_type, + fields=merged_fields, + classification_source=rule_insight.classification_source, + classification_confidence=rule_insight.classification_confidence, + evidence=rule_insight.evidence, + warnings=tuple(warnings), + ) rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE) - warnings = list(rule_insight.warnings) if parsed.document_type != rule_insight.document_type: warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。") @@ -402,7 +445,7 @@ class DocumentIntelligenceService: scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code, scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label, expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type, - fields=fields, + fields=merged_fields, classification_source=source, classification_confidence=max(parsed.confidence, rule_insight.classification_confidence), evidence=tuple(parsed.evidence or rule_insight.evidence), @@ -479,6 +522,115 @@ def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None: return None +def _normalize_llm_document_fields(raw_fields: list[DocumentField] | list[dict[str, Any]]) -> list[DocumentField]: + normalized: list[DocumentField] = [] + seen_keys: set[str] = set() + + for field in raw_fields: + raw_key = str(getattr(field, "key", "") if isinstance(field, DocumentField) else field.get("key") or "").strip() + raw_label = str(getattr(field, "label", "") if isinstance(field, DocumentField) else field.get("label") or "").strip() + raw_value = str(getattr(field, "value", "") if isinstance(field, DocumentField) else field.get("value") or "").strip() + key = _normalize_llm_document_field_key(raw_key, raw_label) + if not key or key in seen_keys: + continue + value = _normalize_llm_document_field_value(key, raw_value) + if not value: + continue + seen_keys.add(key) + normalized.append( + DocumentField( + key=key, + label=_llm_document_field_label(key), + value=value, + ) + ) + + return normalized + + +def _normalize_llm_document_field_key(key: str, label: str) -> str: + compact_key = str(key or "").strip().lower() + compact_label = str(label or "").replace(" ", "").lower() + if compact_key in {"amount", "total_amount", "payment_amount", "paid_amount"} or any( + token in compact_label for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额") + ): + return "amount" + if compact_key in {"date", "time", "issued_at", "invoice_date"} or any( + token in compact_label for token in ("日期", "时间", "开票日期", "发生时间") + ): + return "date" + if compact_key in {"merchant_name", "merchant", "seller_name", "vendor_name"} or any( + token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方") + ): + return "merchant_name" + if compact_key in {"invoice_number", "ticket_number", "order_no", "order_number"} or any( + token in compact_label for token in ("票据号码", "发票号码", "票号", "单号", "订单号") + ): + return "invoice_number" + if compact_key in {"invoice_code"} or "发票代码" in compact_label: + return "invoice_code" + if compact_key in {"trip_no", "flight_no", "train_no"} or any( + token in compact_label for token in ("车次", "航班") + ): + return "trip_no" + if compact_key in {"route", "trip_route"} or any(token in compact_label for token in ("行程", "路线")): + return "route" + return "" + + +def _normalize_llm_document_field_value(key: str, value: str) -> str: + raw_value = str(value or "").strip() + if not raw_value: + return "" + if key == "amount": + amount = _extract_amount(raw_value) + if amount: + return amount + cleaned = raw_value.replace("¥", "").replace("¥", "").replace("元", "").replace(",", ".").strip() + try: + candidate = Decimal(cleaned) + except InvalidOperation: + return "" + if candidate <= Decimal("0.00"): + return "" + text_value = format(candidate.quantize(Decimal("0.01")), "f").rstrip("0").rstrip(".") + return f"{text_value}元" + if key == "date": + return _extract_date(raw_value) or _clean_field_value(raw_value) + if key == "route": + return _extract_route(raw_value) or _clean_field_value( + raw_value.replace("→", "-").replace("至", "-").replace("->", "-") + ) + return _clean_field_value(raw_value) + + +def _llm_document_field_label(key: str) -> str: + return { + "amount": "金额", + "date": "日期", + "merchant_name": "商户", + "invoice_number": "票据号码", + "invoice_code": "发票代码", + "trip_no": "车次/航班", + "route": "行程", + }.get(key, key) + + +def _merge_document_fields( + base_fields: tuple[DocumentField, ...], + override_fields: tuple[DocumentField, ...], +) -> tuple[DocumentField, ...]: + merged = {field.key: field for field in base_fields if field.key and field.value} + order = [field.key for field in base_fields if field.key and field.value] + for field in override_fields: + if not field.key or not field.value: + continue + merged[field.key] = field + if field.key not in order: + order.append(field.key) + return tuple(merged[key] for key in order if key in merged) + + def _extract_document_fields(text: str) -> list[DocumentField]: fields: list[DocumentField] = [] amount = _extract_amount(text) @@ -525,8 +677,6 @@ def _extract_amount(text: str) -> str: continue if best_value is None or candidate > best_value: best_value = candidate - if best_value is not None: - break if best_value is None: return "" diff --git a/server/tests/test_document_intelligence.py b/server/tests/test_document_intelligence.py index 7d57e2d..d95127d 100644 --- a/server/tests/test_document_intelligence.py +++ b/server/tests/test_document_intelligence.py @@ -64,3 +64,55 @@ def test_document_intelligence_service_uses_vlm_result_when_preview_available(mo assert insight.document_type == "taxi_receipt" assert insight.classification_source == "llm_vision" assert calls[0] == ("vlm",) + + +def test_document_intelligence_extracts_larger_decimal_amount_from_multiple_candidates() -> None: + insight = build_document_insight( + filename="taxi-amount.png", + summary="滴滴出行电子行程单", + text="滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678", + ) + + assert any(field.label == "金额" and field.value == "13.4元" for field in insight.fields) + + +def test_document_intelligence_service_uses_vlm_fields_to_correct_amount(monkeypatch) -> None: + def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2): + if slot_priority == ("vlm",): + return json.dumps( + { + "document_type": "taxi_receipt", + "scene_code": "transport", + "scene_label": "交通票据", + "expense_type": "transport", + "confidence": 0.89, + "evidence": ["图片主体为滴滴行程单,金额区域显示 13.4 元"], + "fields": [ + {"key": "amount", "label": "金额", "value": "13.4"}, + {"key": "merchant_name", "label": "商户", "value": "滴滴出行"}, + ], + }, + ensure_ascii=False, + ) + return None + + monkeypatch.setattr(RuntimeChatService, "complete", fake_complete) + + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + session = sessionmaker(bind=engine, autoflush=False, autocommit=False)() + try: + insight = DocumentIntelligenceService(session).build_document_insight( + filename="didi-corrected.png", + summary="滴滴出行电子行程单", + text="滴滴出行 支付金额 1 元 订单号 12345678", + preview_data_url="data:image/png;base64,ZmFrZQ==", + ) + finally: + session.close() + + assert any(field.label == "金额" and field.value == "13.4元" for field in insight.fields) + assert any("大模型复核结果修正" in warning for warning in insight.warnings)