feat(server): 扩展文档智能识别服务,新增Azure Document Intelligence集成和测试用例
This commit is contained in:
@@ -59,6 +59,7 @@ class LlmDocumentClassification(BaseModel):
|
|||||||
expense_type: str = Field(default="other")
|
expense_type: str = Field(default="other")
|
||||||
confidence: float = Field(default=0.0, ge=0.0, le=1.0)
|
confidence: float = Field(default=0.0, ge=0.0, le=1.0)
|
||||||
evidence: list[str] = Field(default_factory=list)
|
evidence: list[str] = Field(default_factory=list)
|
||||||
|
fields: list[DocumentField] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_RULE = DocumentRule(
|
DEFAULT_RULE = DocumentRule(
|
||||||
@@ -177,7 +178,11 @@ DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES}
|
|||||||
SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",)
|
SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",)
|
||||||
|
|
||||||
AMOUNT_PATTERNS = (
|
AMOUNT_PATTERNS = (
|
||||||
re.compile(r"(?:价税合计|合计|金额|总额|票价|支付金额|实付金额|实收金额)[::\s¥¥]*([0-9]+(?:[.,][0-9]{1,2})?)"),
|
re.compile(
|
||||||
|
r"(?:价税合计|合计金额|费用合计|订单(?:总)?金额|支付(?:金额)?|实付(?:金额)?|实收(?:金额)?|总(?:额|计|价)|票价|金额|车费|消费金额)"
|
||||||
|
r"[::\s¥¥人民币]*([0-9]+(?:[.,][0-9]{1,2})?)"
|
||||||
|
),
|
||||||
|
re.compile(r"[¥¥]\s*([0-9]+(?:[.,][0-9]{1,2})?)"),
|
||||||
re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
|
re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
|
||||||
)
|
)
|
||||||
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
|
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
|
||||||
@@ -278,7 +283,7 @@ class DocumentIntelligenceService:
|
|||||||
|
|
||||||
system_prompt = (
|
system_prompt = (
|
||||||
"你是企业报销票据识别复核器。"
|
"你是企业报销票据识别复核器。"
|
||||||
"你的任务不是 OCR,而是在已有 OCR 文本和票据预览基础上判断票据类型。"
|
"你的任务不是 OCR,而是在已有 OCR 文本和票据预览基础上判断票据类型,并尽量复核关键字段。"
|
||||||
"只输出 JSON 对象,不要输出 Markdown、解释或代码块。"
|
"只输出 JSON 对象,不要输出 Markdown、解释或代码块。"
|
||||||
"document_type 只能是:"
|
"document_type 只能是:"
|
||||||
f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}。"
|
f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}。"
|
||||||
@@ -286,7 +291,10 @@ class DocumentIntelligenceService:
|
|||||||
"严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。"
|
"严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。"
|
||||||
"如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。"
|
"如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。"
|
||||||
"例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。"
|
"例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。"
|
||||||
"输出字段:document_type, scene_code, scene_label, expense_type, confidence, evidence。"
|
"如果能从 OCR 或图片中明确确认字段,可在 fields 中返回。"
|
||||||
|
"fields 只允许包含 key, label, value,key 只能是 amount, date, merchant_name, invoice_number, "
|
||||||
|
"invoice_code, trip_no, route。无法确认就不要返回该字段。"
|
||||||
|
"输出字段:document_type, scene_code, scene_label, expense_type, confidence, evidence, fields。"
|
||||||
)
|
)
|
||||||
user_prompt = (
|
user_prompt = (
|
||||||
"请根据以下票据事实给出最终分类 JSON:\n"
|
"请根据以下票据事实给出最终分类 JSON:\n"
|
||||||
@@ -298,7 +306,8 @@ class DocumentIntelligenceService:
|
|||||||
' "scene_label": "交通票据",\n'
|
' "scene_label": "交通票据",\n'
|
||||||
' "expense_type": "transport",\n'
|
' "expense_type": "transport",\n'
|
||||||
' "confidence": 0.86,\n'
|
' "confidence": 0.86,\n'
|
||||||
' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"]\n'
|
' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"],\n'
|
||||||
|
' "fields": [{"key": "amount", "label": "金额", "value": "32.5"}]\n'
|
||||||
"}"
|
"}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -357,6 +366,7 @@ class DocumentIntelligenceService:
|
|||||||
for item in parsed.evidence
|
for item in parsed.evidence
|
||||||
if str(item or "").strip()
|
if str(item or "").strip()
|
||||||
][:4]
|
][:4]
|
||||||
|
normalized_fields = _normalize_llm_document_fields(parsed.fields)
|
||||||
|
|
||||||
return LlmDocumentClassification(
|
return LlmDocumentClassification(
|
||||||
document_type=normalized_type,
|
document_type=normalized_type,
|
||||||
@@ -365,6 +375,7 @@ class DocumentIntelligenceService:
|
|||||||
expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type,
|
expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type,
|
||||||
confidence=float(parsed.confidence),
|
confidence=float(parsed.confidence),
|
||||||
evidence=evidence,
|
evidence=evidence,
|
||||||
|
fields=normalized_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -376,8 +387,28 @@ class DocumentIntelligenceService:
|
|||||||
has_preview: bool,
|
has_preview: bool,
|
||||||
) -> DocumentInsight:
|
) -> DocumentInsight:
|
||||||
source, parsed = llm_result
|
source, parsed = llm_result
|
||||||
|
warnings = list(rule_insight.warnings)
|
||||||
|
merged_fields = rule_insight.fields
|
||||||
|
if parsed.fields and (has_preview or parsed.confidence >= 0.55):
|
||||||
|
merged_fields = _merge_document_fields(rule_insight.fields, tuple(parsed.fields))
|
||||||
|
if merged_fields != rule_insight.fields:
|
||||||
|
warnings.append("票据关键信息已结合大模型复核结果修正,建议人工再核对原图。")
|
||||||
|
|
||||||
if parsed.confidence < 0.55:
|
if parsed.confidence < 0.55:
|
||||||
return rule_insight
|
if merged_fields == rule_insight.fields:
|
||||||
|
return rule_insight
|
||||||
|
return DocumentInsight(
|
||||||
|
document_type=rule_insight.document_type,
|
||||||
|
document_type_label=rule_insight.document_type_label,
|
||||||
|
scene_code=rule_insight.scene_code,
|
||||||
|
scene_label=rule_insight.scene_label,
|
||||||
|
expense_type=rule_insight.expense_type,
|
||||||
|
fields=merged_fields,
|
||||||
|
classification_source=rule_insight.classification_source,
|
||||||
|
classification_confidence=rule_insight.classification_confidence,
|
||||||
|
evidence=rule_insight.evidence,
|
||||||
|
warnings=tuple(warnings),
|
||||||
|
)
|
||||||
|
|
||||||
should_override = False
|
should_override = False
|
||||||
if parsed.document_type == rule_insight.document_type:
|
if parsed.document_type == rule_insight.document_type:
|
||||||
@@ -389,10 +420,22 @@ class DocumentIntelligenceService:
|
|||||||
should_override = parsed.confidence >= threshold
|
should_override = parsed.confidence >= threshold
|
||||||
|
|
||||||
if not should_override:
|
if not should_override:
|
||||||
return rule_insight
|
if merged_fields == rule_insight.fields:
|
||||||
|
return rule_insight
|
||||||
|
return DocumentInsight(
|
||||||
|
document_type=rule_insight.document_type,
|
||||||
|
document_type_label=rule_insight.document_type_label,
|
||||||
|
scene_code=rule_insight.scene_code,
|
||||||
|
scene_label=rule_insight.scene_label,
|
||||||
|
expense_type=rule_insight.expense_type,
|
||||||
|
fields=merged_fields,
|
||||||
|
classification_source=rule_insight.classification_source,
|
||||||
|
classification_confidence=rule_insight.classification_confidence,
|
||||||
|
evidence=rule_insight.evidence,
|
||||||
|
warnings=tuple(warnings),
|
||||||
|
)
|
||||||
|
|
||||||
rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
|
rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
|
||||||
warnings = list(rule_insight.warnings)
|
|
||||||
if parsed.document_type != rule_insight.document_type:
|
if parsed.document_type != rule_insight.document_type:
|
||||||
warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
|
warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
|
||||||
|
|
||||||
@@ -402,7 +445,7 @@ class DocumentIntelligenceService:
|
|||||||
scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
|
scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
|
||||||
scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
|
scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
|
||||||
expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
|
expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
|
||||||
fields=fields,
|
fields=merged_fields,
|
||||||
classification_source=source,
|
classification_source=source,
|
||||||
classification_confidence=max(parsed.confidence, rule_insight.classification_confidence),
|
classification_confidence=max(parsed.confidence, rule_insight.classification_confidence),
|
||||||
evidence=tuple(parsed.evidence or rule_insight.evidence),
|
evidence=tuple(parsed.evidence or rule_insight.evidence),
|
||||||
@@ -479,6 +522,115 @@ def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_llm_document_fields(raw_fields: list[DocumentField] | list[dict[str, Any]]) -> list[DocumentField]:
    """Turn the fields returned by the LLM into canonical ``DocumentField`` items.

    Each raw item is read for ``key``, ``label`` and ``value``. The key is
    mapped to a canonical key, the value is normalized for that key, and the
    label is replaced by the canonical display label. Items are dropped when
    the key cannot be mapped, when the normalized value is empty, or when an
    earlier item already produced the same canonical key.

    Args:
        raw_fields: Items that are either plain dicts or objects exposing
            ``key`` / ``label`` / ``value`` attributes. ``None`` or an empty
            sequence yields an empty list.

    Returns:
        The accepted fields, in input order, at most one per canonical key.
    """

    def _read_text(item: Any, name: str) -> str:
        # Dicts keep the "falsy means empty" rule. For attribute access a
        # missing or None attribute is treated as empty instead of being
        # stringified to the literal text "None".
        if isinstance(item, dict):
            raw = item.get(name) or ""
        else:
            raw = getattr(item, name, None)
            if raw is None:
                raw = ""
        return str(raw).strip()

    normalized: list[DocumentField] = []
    seen_keys: set[str] = set()

    for field in raw_fields or ():
        key = _normalize_llm_document_field_key(
            _read_text(field, "key"),
            _read_text(field, "label"),
        )
        if not key or key in seen_keys:
            continue

        value = _normalize_llm_document_field_value(key, _read_text(field, "value"))
        if not value:
            # The key is not marked as seen, so a later item with the same
            # key and a usable value can still be accepted.
            continue

        seen_keys.add(key)
        normalized.append(
            DocumentField(
                key=key,
                label=_llm_document_field_label(key),
                value=value,
            )
        )

    return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_llm_document_field_key(key: str, label: str) -> str:
|
||||||
|
compact_key = str(key or "").strip().lower()
|
||||||
|
compact_label = str(label or "").replace(" ", "").lower()
|
||||||
|
if compact_key in {"amount", "total_amount", "payment_amount", "paid_amount"} or any(
|
||||||
|
token in compact_label for token in ("金额", "价税合计", "合计", "总额", "总计", "票价", "支付金额", "实付金额", "实收金额")
|
||||||
|
):
|
||||||
|
return "amount"
|
||||||
|
if compact_key in {"date", "time", "issued_at", "invoice_date"} or any(
|
||||||
|
token in compact_label for token in ("日期", "时间", "开票日期", "发生时间")
|
||||||
|
):
|
||||||
|
return "date"
|
||||||
|
if compact_key in {"merchant_name", "merchant", "seller_name", "vendor_name"} or any(
|
||||||
|
token in compact_label for token in ("商户", "酒店", "销售方", "开票方", "收款方")
|
||||||
|
):
|
||||||
|
return "merchant_name"
|
||||||
|
if compact_key in {"invoice_number", "ticket_number", "order_no", "order_number"} or any(
|
||||||
|
token in compact_label for token in ("票据号码", "发票号码", "票号", "单号", "订单号")
|
||||||
|
):
|
||||||
|
return "invoice_number"
|
||||||
|
if compact_key in {"invoice_code"} or "发票代码" in compact_label:
|
||||||
|
return "invoice_code"
|
||||||
|
if compact_key in {"trip_no", "flight_no", "train_no"} or any(
|
||||||
|
token in compact_label for token in ("车次", "航班")
|
||||||
|
):
|
||||||
|
return "trip_no"
|
||||||
|
if compact_key in {"route", "trip_route"} or any(token in compact_label for token in ("行程", "路线")):
|
||||||
|
return "route"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_llm_document_field_value(key: str, value: str) -> str:
    """Normalize a raw LLM field value according to its canonical key.

    Args:
        key: Canonical field key (``amount``, ``date``, ``route``, ...).
        value: Raw value text returned by the model.

    Returns:
        The normalized value, or ``""`` when the value is empty or, for
        amounts, cannot be read as a positive finite number.

        * ``amount``: the result of ``_extract_amount`` when it finds one;
          otherwise the text is stripped of currency marks, parsed as a
          decimal, rounded to two places and rendered as ``"<number>元"``
          without trailing zeros.
        * ``date``: ``_extract_date`` result, else the cleaned raw text.
        * ``route``: ``_extract_route`` result, else the cleaned raw text with
          arrow-like separators replaced by ``-``.
        * anything else: the cleaned raw text.
    """
    raw_value = str(value or "").strip()
    if not raw_value:
        return ""

    if key == "amount":
        amount = _extract_amount(raw_value)
        if amount:
            return amount

        cleaned = raw_value.replace("¥", "").replace("¥", "").replace("元", "").replace(",", ".").strip()
        try:
            candidate = Decimal(cleaned)
        except InvalidOperation:
            return ""

        # Decimal() accepts "NaN" and "Infinity"; ordering a NaN or quantizing
        # an infinity raises InvalidOperation, so reject them up front.
        if not candidate.is_finite() or candidate <= Decimal("0.00"):
            return ""

        try:
            quantized = candidate.quantize(Decimal("0.01"))
        except InvalidOperation:
            # Raised when the rounded result would exceed the context precision.
            return ""

        # A positive amount below one cent rounds to zero; do not report "0元".
        if quantized <= Decimal("0.00"):
            return ""

        text_value = format(quantized, "f").rstrip("0").rstrip(".")
        return f"{text_value}元"

    if key == "date":
        return _extract_date(raw_value) or _clean_field_value(raw_value)

    if key == "route":
        return _extract_route(raw_value) or _clean_field_value(
            raw_value.replace("→", "-").replace("至", "-").replace("->", "-")
        )

    return _clean_field_value(raw_value)
|
||||||
|
|
||||||
|
|
||||||
|
def _llm_document_field_label(key: str) -> str:
|
||||||
|
return {
|
||||||
|
"amount": "金额",
|
||||||
|
"date": "日期",
|
||||||
|
"merchant_name": "商户",
|
||||||
|
"invoice_number": "票据号码",
|
||||||
|
"invoice_code": "发票代码",
|
||||||
|
"trip_no": "车次/航班",
|
||||||
|
"route": "行程",
|
||||||
|
}.get(key, key)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_document_fields(
|
||||||
|
base_fields: tuple[DocumentField, ...],
|
||||||
|
override_fields: tuple[DocumentField, ...],
|
||||||
|
) -> tuple[DocumentField, ...]:
|
||||||
|
merged = {field.key: field for field in base_fields if field.key and field.value}
|
||||||
|
order = [field.key for field in base_fields if field.key and field.value]
|
||||||
|
for field in override_fields:
|
||||||
|
if not field.key or not field.value:
|
||||||
|
continue
|
||||||
|
merged[field.key] = field
|
||||||
|
if field.key not in order:
|
||||||
|
order.append(field.key)
|
||||||
|
return tuple(merged[key] for key in order if key in merged)
|
||||||
|
|
||||||
|
|
||||||
def _extract_document_fields(text: str) -> list[DocumentField]:
|
def _extract_document_fields(text: str) -> list[DocumentField]:
|
||||||
fields: list[DocumentField] = []
|
fields: list[DocumentField] = []
|
||||||
amount = _extract_amount(text)
|
amount = _extract_amount(text)
|
||||||
@@ -525,8 +677,6 @@ def _extract_amount(text: str) -> str:
|
|||||||
continue
|
continue
|
||||||
if best_value is None or candidate > best_value:
|
if best_value is None or candidate > best_value:
|
||||||
best_value = candidate
|
best_value = candidate
|
||||||
if best_value is not None:
|
|
||||||
break
|
|
||||||
|
|
||||||
if best_value is None:
|
if best_value is None:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -64,3 +64,55 @@ def test_document_intelligence_service_uses_vlm_result_when_preview_available(mo
|
|||||||
assert insight.document_type == "taxi_receipt"
|
assert insight.document_type == "taxi_receipt"
|
||||||
assert insight.classification_source == "llm_vision"
|
assert insight.classification_source == "llm_vision"
|
||||||
assert calls[0] == ("vlm",)
|
assert calls[0] == ("vlm",)
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_intelligence_extracts_larger_decimal_amount_from_multiple_candidates() -> None:
    """The OCR text carries two amount candidates (1 and 13.4); the insight must report 13.4."""
    ocr_text = "滴滴出行 支付金额 1 元,实付 13.4 元,订单号 12345678"

    insight = build_document_insight(
        filename="taxi-amount.png",
        summary="滴滴出行电子行程单",
        text=ocr_text,
    )

    amount_values = [field.value for field in insight.fields if field.label == "金额"]
    assert "13.4元" in amount_values
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_intelligence_service_uses_vlm_fields_to_correct_amount(monkeypatch) -> None:
    """OCR text says 1 yuan while the stubbed vision reply reports 13.4.

    The resulting insight must expose the 13.4 amount and carry the warning
    stating that fields were corrected from the model review.
    """
    vlm_reply = json.dumps(
        {
            "document_type": "taxi_receipt",
            "scene_code": "transport",
            "scene_label": "交通票据",
            "expense_type": "transport",
            "confidence": 0.89,
            "evidence": ["图片主体为滴滴行程单,金额区域显示 13.4 元"],
            "fields": [
                {"key": "amount", "label": "金额", "value": "13.4"},
                {"key": "merchant_name", "label": "商户", "value": "滴滴出行"},
            ],
        },
        ensure_ascii=False,
    )

    def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2):
        # Only the vision slot answers; every other slot behaves as unavailable.
        if slot_priority != ("vlm",):
            return None
        return vlm_reply

    monkeypatch.setattr(RuntimeChatService, "complete", fake_complete)

    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
    session = session_factory()
    try:
        service = DocumentIntelligenceService(session)
        insight = service.build_document_insight(
            filename="didi-corrected.png",
            summary="滴滴出行电子行程单",
            text="滴滴出行 支付金额 1 元 订单号 12345678",
            preview_data_url="data:image/png;base64,ZmFrZQ==",
        )
    finally:
        session.close()

    amount_values = [field.value for field in insight.fields if field.label == "金额"]
    assert "13.4元" in amount_values
    assert any("大模型复核结果修正" in warning for warning in insight.warnings)
|
||||||
|
|||||||
Reference in New Issue
Block a user