feat(server): 新增文档智能识别服务,扩展OCR接口支持 Azure Document Intelligence
This commit is contained in:
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.api.deps import CurrentUserContext, get_current_user
|
from app.api.deps import CurrentUserContext, get_current_user, get_db
|
||||||
from app.schemas.common import ErrorResponse
|
from app.schemas.common import ErrorResponse
|
||||||
from app.schemas.ocr import OcrRecognizeBatchRead
|
from app.schemas.ocr import OcrRecognizeBatchRead
|
||||||
from app.services.ocr import OcrService
|
from app.services.ocr import OcrService
|
||||||
@@ -35,6 +36,7 @@ router = APIRouter(prefix="/ocr")
|
|||||||
async def recognize_ocr_documents(
|
async def recognize_ocr_documents(
|
||||||
files: Annotated[list[UploadFile], File(description="待识别的票据图片或 PDF。")],
|
files: Annotated[list[UploadFile], File(description="待识别的票据图片或 PDF。")],
|
||||||
_: Annotated[CurrentUserContext, Depends(get_current_user)],
|
_: Annotated[CurrentUserContext, Depends(get_current_user)],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
) -> OcrRecognizeBatchRead:
|
) -> OcrRecognizeBatchRead:
|
||||||
try:
|
try:
|
||||||
payload = []
|
payload = []
|
||||||
@@ -46,7 +48,7 @@ async def recognize_ocr_documents(
|
|||||||
upload.content_type,
|
upload.content_type,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return OcrService().recognize_files(payload)
|
return OcrService(db).recognize_files(payload)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||||
except RuntimeError as exc:
|
except RuntimeError as exc:
|
||||||
|
|||||||
@@ -10,6 +10,12 @@ class OcrRecognizeLineRead(BaseModel):
|
|||||||
page_index: int | None = Field(default=None, description="页码,从 0 开始。")
|
page_index: int | None = Field(default=None, description="页码,从 0 开始。")
|
||||||
|
|
||||||
|
|
||||||
|
class OcrRecognizeFieldRead(BaseModel):
|
||||||
|
key: str = Field(description="结构化字段键。")
|
||||||
|
label: str = Field(description="结构化字段展示名。")
|
||||||
|
value: str = Field(default="", description="结构化字段值。")
|
||||||
|
|
||||||
|
|
||||||
class OcrRecognizeDocumentRead(BaseModel):
|
class OcrRecognizeDocumentRead(BaseModel):
|
||||||
filename: str = Field(description="原始文件名。")
|
filename: str = Field(description="原始文件名。")
|
||||||
media_type: str = Field(description="文件媒体类型。")
|
media_type: str = Field(description="文件媒体类型。")
|
||||||
@@ -20,6 +26,19 @@ class OcrRecognizeDocumentRead(BaseModel):
|
|||||||
avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="平均识别置信度。")
|
avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="平均识别置信度。")
|
||||||
line_count: int = Field(default=0, ge=0, description="文本行数。")
|
line_count: int = Field(default=0, ge=0, description="文本行数。")
|
||||||
page_count: int = Field(default=1, ge=0, description="识别页数。")
|
page_count: int = Field(default=1, ge=0, description="识别页数。")
|
||||||
|
document_type: str = Field(default="other", description="识别出的票据类型编码。")
|
||||||
|
document_type_label: str = Field(default="其他单据", description="识别出的票据类型名称。")
|
||||||
|
scene_code: str = Field(default="other", description="识别出的票据场景编码。")
|
||||||
|
scene_label: str = Field(default="其他票据", description="识别出的票据场景名称。")
|
||||||
|
classification_source: str = Field(default="rule", description="票据类型判断来源,例如 rule / llm_text / llm_vision。")
|
||||||
|
classification_confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="票据类型判断置信度。")
|
||||||
|
classification_evidence: list[str] = Field(default_factory=list, description="票据类型判断依据摘要。")
|
||||||
|
document_fields: list[OcrRecognizeFieldRead] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="识别出的结构化票据信息。",
|
||||||
|
)
|
||||||
|
preview_kind: str = Field(default="", description="预览类型,PDF 转图后通常为 image。")
|
||||||
|
preview_data_url: str = Field(default="", description="用于前端展示的图片预览 data URL。")
|
||||||
warnings: list[str] = Field(default_factory=list, description="该文件的识别提示或警告。")
|
warnings: list[str] = Field(default_factory=list, description="该文件的识别提示或警告。")
|
||||||
lines: list[OcrRecognizeLineRead] = Field(default_factory=list, description="逐行识别结果。")
|
lines: list[OcrRecognizeLineRead] = Field(default_factory=list, description="逐行识别结果。")
|
||||||
|
|
||||||
|
|||||||
582
server/src/app/services/document_intelligence.py
Normal file
582
server/src/app/services/document_intelligence.py
Normal file
@@ -0,0 +1,582 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from decimal import Decimal, InvalidOperation
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.services.runtime_chat import RuntimeChatService
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class DocumentField:
|
||||||
|
key: str
|
||||||
|
label: str
|
||||||
|
value: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class DocumentInsight:
|
||||||
|
document_type: str
|
||||||
|
document_type_label: str
|
||||||
|
scene_code: str
|
||||||
|
scene_label: str
|
||||||
|
expense_type: str
|
||||||
|
fields: tuple[DocumentField, ...] = ()
|
||||||
|
classification_source: str = "rule"
|
||||||
|
classification_confidence: float = 0.0
|
||||||
|
evidence: tuple[str, ...] = ()
|
||||||
|
warnings: tuple[str, ...] = ()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class DocumentRule:
|
||||||
|
document_type: str
|
||||||
|
document_type_label: str
|
||||||
|
scene_code: str
|
||||||
|
scene_label: str
|
||||||
|
expense_type: str
|
||||||
|
keywords: tuple[str, ...]
|
||||||
|
score_bias: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class RuleMatch:
|
||||||
|
rule: DocumentRule | None
|
||||||
|
confidence: float
|
||||||
|
evidence: tuple[str, ...]
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
class LlmDocumentClassification(BaseModel):
|
||||||
|
document_type: str = Field(default="other")
|
||||||
|
scene_code: str = Field(default="other")
|
||||||
|
scene_label: str = Field(default="其他票据")
|
||||||
|
expense_type: str = Field(default="other")
|
||||||
|
confidence: float = Field(default=0.0, ge=0.0, le=1.0)
|
||||||
|
evidence: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_RULE = DocumentRule(
|
||||||
|
document_type="other",
|
||||||
|
document_type_label="其他单据",
|
||||||
|
scene_code="other",
|
||||||
|
scene_label="其他票据",
|
||||||
|
expense_type="other",
|
||||||
|
keywords=(),
|
||||||
|
score_bias=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCUMENT_RULES: tuple[DocumentRule, ...] = (
|
||||||
|
DocumentRule(
|
||||||
|
document_type="flight_itinerary",
|
||||||
|
document_type_label="机票/航班行程单",
|
||||||
|
scene_code="travel",
|
||||||
|
scene_label="差旅票据",
|
||||||
|
expense_type="travel",
|
||||||
|
keywords=("电子行程单", "航班号", "航班", "机票", "登机", "航空", "客票"),
|
||||||
|
score_bias=0.34,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="train_ticket",
|
||||||
|
document_type_label="火车/高铁票",
|
||||||
|
scene_code="travel",
|
||||||
|
scene_label="差旅票据",
|
||||||
|
expense_type="travel",
|
||||||
|
keywords=("高铁", "火车", "动车", "铁路", "车次", "检票", "二等座", "一等座"),
|
||||||
|
score_bias=0.32,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="hotel_invoice",
|
||||||
|
document_type_label="酒店住宿票据",
|
||||||
|
scene_code="hotel",
|
||||||
|
scene_label="住宿票据",
|
||||||
|
expense_type="hotel",
|
||||||
|
keywords=("住宿", "房费", "客房", "入住", "离店", "酒店", "宾馆", "间夜"),
|
||||||
|
score_bias=0.16,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="taxi_receipt",
|
||||||
|
document_type_label="出租车/网约车票据",
|
||||||
|
scene_code="transport",
|
||||||
|
scene_label="交通票据",
|
||||||
|
expense_type="transport",
|
||||||
|
keywords=("滴滴出行", "滴滴", "网约车", "出租车", "打车", "快车", "专车", "订单号", "上车", "下车", "起点", "终点", "里程", "司机"),
|
||||||
|
score_bias=0.38,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="parking_toll_receipt",
|
||||||
|
document_type_label="停车/通行费票据",
|
||||||
|
scene_code="transport",
|
||||||
|
scene_label="交通票据",
|
||||||
|
expense_type="transport",
|
||||||
|
keywords=("停车费", "通行费", "过路费", "收费站", "停车场", "停车"),
|
||||||
|
score_bias=0.28,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="meal_receipt",
|
||||||
|
document_type_label="餐饮票据",
|
||||||
|
scene_code="meal",
|
||||||
|
scene_label="餐饮票据",
|
||||||
|
expense_type="meal",
|
||||||
|
keywords=("餐饮", "餐费", "用餐", "饭店", "酒楼", "餐厅", "食品", "外卖", "咖啡"),
|
||||||
|
score_bias=0.14,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="office_invoice",
|
||||||
|
document_type_label="办公用品票据",
|
||||||
|
scene_code="office",
|
||||||
|
scene_label="办公用品票据",
|
||||||
|
expense_type="office",
|
||||||
|
keywords=("办公用品", "文具", "耗材", "打印纸", "墨盒", "硒鼓", "键盘", "鼠标"),
|
||||||
|
score_bias=0.14,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="meeting_invoice",
|
||||||
|
document_type_label="会议/会务票据",
|
||||||
|
scene_code="meeting",
|
||||||
|
scene_label="会务票据",
|
||||||
|
expense_type="meeting",
|
||||||
|
keywords=("会议", "会务", "会展", "论坛", "会议室", "会场"),
|
||||||
|
score_bias=0.12,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="training_invoice",
|
||||||
|
document_type_label="培训票据",
|
||||||
|
scene_code="training",
|
||||||
|
scene_label="培训票据",
|
||||||
|
expense_type="training",
|
||||||
|
keywords=("培训", "课程", "讲师", "教材", "学费", "认证"),
|
||||||
|
score_bias=0.12,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="vat_invoice",
|
||||||
|
document_type_label="增值税发票",
|
||||||
|
scene_code="other",
|
||||||
|
scene_label="通用发票",
|
||||||
|
expense_type="other",
|
||||||
|
keywords=("发票代码", "发票号码", "价税合计", "增值税", "电子发票"),
|
||||||
|
score_bias=-0.08,
|
||||||
|
),
|
||||||
|
DocumentRule(
|
||||||
|
document_type="receipt",
|
||||||
|
document_type_label="一般收据/凭证",
|
||||||
|
scene_code="other",
|
||||||
|
scene_label="其他票据",
|
||||||
|
expense_type="other",
|
||||||
|
keywords=("收据", "凭证", "票据"),
|
||||||
|
score_bias=-0.18,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES}
|
||||||
|
SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",)
|
||||||
|
|
||||||
|
AMOUNT_PATTERNS = (
|
||||||
|
re.compile(r"(?:价税合计|合计|金额|总额|票价|支付金额|实付金额|实收金额)[::\s¥¥]*([0-9]+(?:[.,][0-9]{1,2})?)"),
|
||||||
|
re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
|
||||||
|
)
|
||||||
|
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
|
||||||
|
INVOICE_NUMBER_PATTERN = re.compile(r"(?:发票号码|票号|单号|订单号)[::\s]*([A-Za-z0-9-]{6,24})")
|
||||||
|
INVOICE_CODE_PATTERN = re.compile(r"(?:发票代码)[::\s]*([A-Za-z0-9-]{6,24})")
|
||||||
|
TRIP_NO_PATTERN = re.compile(r"(?:车次|航班(?:号)?)[::\s]*([A-Za-z0-9]{2,12})")
|
||||||
|
ROUTE_PATTERN = re.compile(r"([\u4e00-\u9fa5]{2,12})\s*(?:至|→|->|-)\s*([\u4e00-\u9fa5]{2,12})")
|
||||||
|
MERCHANT_PATTERNS = (
|
||||||
|
re.compile(r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)[::\s]*([A-Za-z0-9\u4e00-\u9fa5()()·&\\-]{2,40})"),
|
||||||
|
re.compile(r"([A-Za-z0-9\u4e00-\u9fa5()()·&\\-]{2,40}(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentIntelligenceService:
|
||||||
|
def __init__(self, db: Session | None = None) -> None:
|
||||||
|
self.runtime_chat_service = RuntimeChatService(db) if db is not None else None
|
||||||
|
|
||||||
|
def build_document_insight(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
filename: str = "",
|
||||||
|
summary: str = "",
|
||||||
|
text: str = "",
|
||||||
|
preview_data_url: str = "",
|
||||||
|
) -> DocumentInsight:
|
||||||
|
raw_text = " ".join(
|
||||||
|
[str(filename or "").strip(), str(summary or "").strip(), str(text or "").strip()]
|
||||||
|
).strip()
|
||||||
|
compact = re.sub(r"\s+", "", raw_text).lower()
|
||||||
|
rule_match = _match_document_rule(compact)
|
||||||
|
base_rule = rule_match.rule or DEFAULT_RULE
|
||||||
|
fields = tuple(_extract_document_fields(raw_text))
|
||||||
|
rule_insight = DocumentInsight(
|
||||||
|
document_type=base_rule.document_type,
|
||||||
|
document_type_label=base_rule.document_type_label,
|
||||||
|
scene_code=base_rule.scene_code,
|
||||||
|
scene_label=base_rule.scene_label,
|
||||||
|
expense_type=base_rule.expense_type,
|
||||||
|
fields=fields,
|
||||||
|
classification_source="rule",
|
||||||
|
classification_confidence=rule_match.confidence,
|
||||||
|
evidence=rule_match.evidence,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_result = self._classify_with_model(
|
||||||
|
filename=str(filename or "").strip(),
|
||||||
|
summary=str(summary or "").strip(),
|
||||||
|
text=str(text or "").strip(),
|
||||||
|
preview_data_url=str(preview_data_url or "").strip(),
|
||||||
|
rule_insight=rule_insight,
|
||||||
|
fields=fields,
|
||||||
|
)
|
||||||
|
if llm_result is None:
|
||||||
|
return rule_insight
|
||||||
|
return self._merge_rule_and_model(
|
||||||
|
rule_insight=rule_insight,
|
||||||
|
llm_result=llm_result,
|
||||||
|
fields=fields,
|
||||||
|
has_preview=bool(preview_data_url),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _classify_with_model(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
filename: str,
|
||||||
|
summary: str,
|
||||||
|
text: str,
|
||||||
|
preview_data_url: str,
|
||||||
|
rule_insight: DocumentInsight,
|
||||||
|
fields: tuple[DocumentField, ...],
|
||||||
|
) -> tuple[str, LlmDocumentClassification] | None:
|
||||||
|
if self.runtime_chat_service is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
trimmed_text = text.strip()
|
||||||
|
if not trimmed_text and not summary.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
facts = {
|
||||||
|
"filename": filename,
|
||||||
|
"summary": summary[:300],
|
||||||
|
"ocr_text_excerpt": trimmed_text[:2000],
|
||||||
|
"rule_candidate": {
|
||||||
|
"document_type": rule_insight.document_type,
|
||||||
|
"document_type_label": rule_insight.document_type_label,
|
||||||
|
"scene_code": rule_insight.scene_code,
|
||||||
|
"scene_label": rule_insight.scene_label,
|
||||||
|
"expense_type": rule_insight.expense_type,
|
||||||
|
"confidence": round(rule_insight.classification_confidence, 2),
|
||||||
|
"evidence": list(rule_insight.evidence),
|
||||||
|
},
|
||||||
|
"extracted_fields": [
|
||||||
|
{"key": field.key, "label": field.label, "value": field.value}
|
||||||
|
for field in fields
|
||||||
|
],
|
||||||
|
"allowed_document_types": list(SUPPORTED_DOCUMENT_TYPES),
|
||||||
|
}
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"你是企业报销票据识别复核器。"
|
||||||
|
"你的任务不是 OCR,而是在已有 OCR 文本和票据预览基础上判断票据类型。"
|
||||||
|
"只输出 JSON 对象,不要输出 Markdown、解释或代码块。"
|
||||||
|
"document_type 只能是:"
|
||||||
|
f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}。"
|
||||||
|
"如果证据不足,返回 other。"
|
||||||
|
"严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。"
|
||||||
|
"如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。"
|
||||||
|
"例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。"
|
||||||
|
"输出字段:document_type, scene_code, scene_label, expense_type, confidence, evidence。"
|
||||||
|
)
|
||||||
|
user_prompt = (
|
||||||
|
"请根据以下票据事实给出最终分类 JSON:\n"
|
||||||
|
f"{json.dumps(facts, ensure_ascii=False, indent=2)}\n\n"
|
||||||
|
"示例输出:\n"
|
||||||
|
"{\n"
|
||||||
|
' "document_type": "taxi_receipt",\n'
|
||||||
|
' "scene_code": "transport",\n'
|
||||||
|
' "scene_label": "交通票据",\n'
|
||||||
|
' "expense_type": "transport",\n'
|
||||||
|
' "confidence": 0.86,\n'
|
||||||
|
' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"]\n'
|
||||||
|
"}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if preview_data_url:
|
||||||
|
response_text = self.runtime_chat_service.complete(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": user_prompt},
|
||||||
|
{"type": "image_url", "image_url": {"url": preview_data_url}},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
slot_priority=("vlm",),
|
||||||
|
max_tokens=320,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
parsed = self._parse_llm_payload(response_text)
|
||||||
|
if parsed is not None:
|
||||||
|
return "llm_vision", parsed
|
||||||
|
|
||||||
|
response_text = self.runtime_chat_service.complete(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_prompt},
|
||||||
|
],
|
||||||
|
slot_priority=("main", "backup"),
|
||||||
|
max_tokens=320,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
parsed = self._parse_llm_payload(response_text)
|
||||||
|
if parsed is not None:
|
||||||
|
return "llm_text", parsed
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None:
|
||||||
|
payload_json = _extract_json_payload(response_text)
|
||||||
|
if payload_json is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = LlmDocumentClassification.model_validate(payload_json)
|
||||||
|
except ValidationError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized_type = str(parsed.document_type or "other").strip().lower() or "other"
|
||||||
|
if normalized_type not in SUPPORTED_DOCUMENT_TYPES:
|
||||||
|
normalized_type = "other"
|
||||||
|
|
||||||
|
base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE)
|
||||||
|
evidence = [
|
||||||
|
str(item or "").strip()
|
||||||
|
for item in parsed.evidence
|
||||||
|
if str(item or "").strip()
|
||||||
|
][:4]
|
||||||
|
|
||||||
|
return LlmDocumentClassification(
|
||||||
|
document_type=normalized_type,
|
||||||
|
scene_code=str(parsed.scene_code or base_rule.scene_code).strip() or base_rule.scene_code,
|
||||||
|
scene_label=str(parsed.scene_label or base_rule.scene_label).strip() or base_rule.scene_label,
|
||||||
|
expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type,
|
||||||
|
confidence=float(parsed.confidence),
|
||||||
|
evidence=evidence,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_rule_and_model(
|
||||||
|
*,
|
||||||
|
rule_insight: DocumentInsight,
|
||||||
|
llm_result: tuple[str, LlmDocumentClassification],
|
||||||
|
fields: tuple[DocumentField, ...],
|
||||||
|
has_preview: bool,
|
||||||
|
) -> DocumentInsight:
|
||||||
|
source, parsed = llm_result
|
||||||
|
if parsed.confidence < 0.55:
|
||||||
|
return rule_insight
|
||||||
|
|
||||||
|
should_override = False
|
||||||
|
if parsed.document_type == rule_insight.document_type:
|
||||||
|
should_override = True
|
||||||
|
elif rule_insight.document_type == "other" and parsed.document_type != "other":
|
||||||
|
should_override = True
|
||||||
|
elif parsed.document_type != "other":
|
||||||
|
threshold = 0.60 if has_preview else max(0.76, rule_insight.classification_confidence + 0.12)
|
||||||
|
should_override = parsed.confidence >= threshold
|
||||||
|
|
||||||
|
if not should_override:
|
||||||
|
return rule_insight
|
||||||
|
|
||||||
|
rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
|
||||||
|
warnings = list(rule_insight.warnings)
|
||||||
|
if parsed.document_type != rule_insight.document_type:
|
||||||
|
warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
|
||||||
|
|
||||||
|
return DocumentInsight(
|
||||||
|
document_type=rule.document_type,
|
||||||
|
document_type_label=rule.document_type_label,
|
||||||
|
scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
|
||||||
|
scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
|
||||||
|
expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
|
||||||
|
fields=fields,
|
||||||
|
classification_source=source,
|
||||||
|
classification_confidence=max(parsed.confidence, rule_insight.classification_confidence),
|
||||||
|
evidence=tuple(parsed.evidence or rule_insight.evidence),
|
||||||
|
warnings=tuple(warnings),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_document_insight(
|
||||||
|
*,
|
||||||
|
filename: str = "",
|
||||||
|
summary: str = "",
|
||||||
|
text: str = "",
|
||||||
|
preview_data_url: str = "",
|
||||||
|
) -> DocumentInsight:
|
||||||
|
return DocumentIntelligenceService().build_document_insight(
|
||||||
|
filename=filename,
|
||||||
|
summary=summary,
|
||||||
|
text=text,
|
||||||
|
preview_data_url=preview_data_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _match_document_rule(compact_text: str) -> RuleMatch:
|
||||||
|
best_rule = DEFAULT_RULE
|
||||||
|
best_evidence: tuple[str, ...] = ()
|
||||||
|
best_score = 0.0
|
||||||
|
|
||||||
|
for rule in DOCUMENT_RULES:
|
||||||
|
matched = tuple(keyword for keyword in rule.keywords if keyword.lower() in compact_text)
|
||||||
|
if not matched:
|
||||||
|
continue
|
||||||
|
score = float(rule.score_bias) + len(matched) * 0.92 + sum(min(len(keyword), 6) * 0.08 for keyword in matched)
|
||||||
|
if score > best_score:
|
||||||
|
best_rule = rule
|
||||||
|
best_evidence = matched
|
||||||
|
best_score = score
|
||||||
|
|
||||||
|
if best_score <= 0:
|
||||||
|
return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0)
|
||||||
|
|
||||||
|
confidence = min(0.94, 0.30 + min(best_score, 4.8) * 0.12)
|
||||||
|
return RuleMatch(
|
||||||
|
rule=best_rule,
|
||||||
|
confidence=round(confidence, 2),
|
||||||
|
evidence=best_evidence[:4],
|
||||||
|
score=best_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
|
||||||
|
if not response_text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE).strip()
|
||||||
|
if not cleaned:
|
||||||
|
return None
|
||||||
|
|
||||||
|
fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL)
|
||||||
|
candidates = [fenced_match.group(1)] if fenced_match else []
|
||||||
|
candidates.append(cleaned)
|
||||||
|
|
||||||
|
start = cleaned.find("{")
|
||||||
|
end = cleaned.rfind("}")
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
candidates.append(cleaned[start : end + 1])
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(candidate)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return parsed
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_document_fields(text: str) -> list[DocumentField]:
|
||||||
|
fields: list[DocumentField] = []
|
||||||
|
amount = _extract_amount(text)
|
||||||
|
if amount:
|
||||||
|
fields.append(DocumentField(key="amount", label="金额", value=amount))
|
||||||
|
|
||||||
|
date_value = _extract_date(text)
|
||||||
|
if date_value:
|
||||||
|
fields.append(DocumentField(key="date", label="日期", value=date_value))
|
||||||
|
|
||||||
|
merchant = _extract_merchant(text)
|
||||||
|
if merchant:
|
||||||
|
fields.append(DocumentField(key="merchant_name", label="商户", value=merchant))
|
||||||
|
|
||||||
|
invoice_number = _extract_pattern(INVOICE_NUMBER_PATTERN, text)
|
||||||
|
if invoice_number:
|
||||||
|
fields.append(DocumentField(key="invoice_number", label="票据号码", value=invoice_number))
|
||||||
|
|
||||||
|
invoice_code = _extract_pattern(INVOICE_CODE_PATTERN, text)
|
||||||
|
if invoice_code:
|
||||||
|
fields.append(DocumentField(key="invoice_code", label="发票代码", value=invoice_code))
|
||||||
|
|
||||||
|
trip_no = _extract_pattern(TRIP_NO_PATTERN, text)
|
||||||
|
if trip_no:
|
||||||
|
fields.append(DocumentField(key="trip_no", label="车次/航班", value=trip_no))
|
||||||
|
|
||||||
|
route = _extract_route(text)
|
||||||
|
if route:
|
||||||
|
fields.append(DocumentField(key="route", label="行程", value=route))
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_amount(text: str) -> str:
|
||||||
|
best_value: Decimal | None = None
|
||||||
|
for pattern in AMOUNT_PATTERNS:
|
||||||
|
for match in pattern.finditer(text):
|
||||||
|
raw_value = str(match.group(1) or "").replace(",", ".").strip()
|
||||||
|
try:
|
||||||
|
candidate = Decimal(raw_value)
|
||||||
|
except InvalidOperation:
|
||||||
|
continue
|
||||||
|
if candidate <= Decimal("0.00"):
|
||||||
|
continue
|
||||||
|
if best_value is None or candidate > best_value:
|
||||||
|
best_value = candidate
|
||||||
|
if best_value is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
if best_value is None:
|
||||||
|
return ""
|
||||||
|
normalized = best_value.quantize(Decimal("0.01"))
|
||||||
|
text_value = format(normalized, "f").rstrip("0").rstrip(".")
|
||||||
|
return f"{text_value}元"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_date(text: str) -> str:
|
||||||
|
match = DATE_PATTERN.search(text)
|
||||||
|
if not match:
|
||||||
|
return ""
|
||||||
|
raw_value = str(match.group(1) or "").strip()
|
||||||
|
normalized = raw_value.replace("年", "-").replace("月", "-").replace("日", "")
|
||||||
|
normalized = normalized.replace("/", "-").replace(".", "-")
|
||||||
|
parts = [part for part in normalized.split("-") if part]
|
||||||
|
if len(parts) != 3:
|
||||||
|
return raw_value
|
||||||
|
year, month, day = parts
|
||||||
|
return f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_merchant(text: str) -> str:
|
||||||
|
for pattern in MERCHANT_PATTERNS:
|
||||||
|
match = pattern.search(text)
|
||||||
|
if not match:
|
||||||
|
continue
|
||||||
|
value = _clean_field_value(match.group(1))
|
||||||
|
if value:
|
||||||
|
return value
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_route(text: str) -> str:
|
||||||
|
match = ROUTE_PATTERN.search(text)
|
||||||
|
if not match:
|
||||||
|
return ""
|
||||||
|
start = _clean_field_value(match.group(1))
|
||||||
|
end = _clean_field_value(match.group(2))
|
||||||
|
if not start or not end or start == end:
|
||||||
|
return ""
|
||||||
|
return f"{start}-{end}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pattern(pattern: re.Pattern[str], text: str) -> str:
|
||||||
|
match = pattern.search(text)
|
||||||
|
if not match:
|
||||||
|
return ""
|
||||||
|
return _clean_field_value(match.group(1))
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_field_value(value: str) -> str:
|
||||||
|
return str(value or "").strip().strip("::,,。.;;")
|
||||||
@@ -1,21 +1,55 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import SERVER_DIR, get_settings
|
from app.core.config import SERVER_DIR, get_settings
|
||||||
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
|
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
|
||||||
|
from app.services.document_intelligence import DocumentIntelligenceService
|
||||||
|
|
||||||
WORKER_JSON_PREFIX = "__OCR_JSON__="
|
WORKER_JSON_PREFIX = "__OCR_JSON__="
|
||||||
SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"}
|
SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class PreparedOcrInput:
|
||||||
|
input_path: Path
|
||||||
|
source_key: str
|
||||||
|
filename: str
|
||||||
|
media_type: str
|
||||||
|
page_index: int | None = None
|
||||||
|
preview_kind: str = ""
|
||||||
|
preview_data_url: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AggregatedOcrDocument:
|
||||||
|
filename: str
|
||||||
|
media_type: str
|
||||||
|
source_key: str
|
||||||
|
engine: str = "paddleocr_mobile"
|
||||||
|
model: str = "PP-OCRv5_mobile"
|
||||||
|
summary_fragments: list[str] = field(default_factory=list)
|
||||||
|
text_fragments: list[str] = field(default_factory=list)
|
||||||
|
score_values: list[float] = field(default_factory=list)
|
||||||
|
warnings: list[str] = field(default_factory=list)
|
||||||
|
lines: list[OcrRecognizeLineRead] = field(default_factory=list)
|
||||||
|
page_count: int = 0
|
||||||
|
preview_kind: str = ""
|
||||||
|
preview_data_url: str = ""
|
||||||
|
|
||||||
|
|
||||||
class OcrService:
|
class OcrService:
|
||||||
def __init__(self) -> None:
|
def __init__(self, db: Session | None = None) -> None:
|
||||||
self.settings = get_settings()
|
self.settings = get_settings()
|
||||||
|
self.document_intelligence_service = DocumentIntelligenceService(db)
|
||||||
|
|
||||||
def recognize_files(
|
def recognize_files(
|
||||||
self,
|
self,
|
||||||
@@ -28,10 +62,11 @@ class OcrService:
|
|||||||
temp_root.mkdir(parents=True, exist_ok=True)
|
temp_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
documents: list[OcrRecognizeDocumentRead] = []
|
documents: list[OcrRecognizeDocumentRead] = []
|
||||||
input_paths: list[Path] = []
|
prepared_inputs: list[PreparedOcrInput] = []
|
||||||
meta_by_path: dict[str, tuple[str, str]] = {}
|
cleanup_paths: list[Path] = []
|
||||||
python_bin = self._resolve_python_bin()
|
python_bin = self._resolve_python_bin()
|
||||||
worker_path = self._resolve_worker_path()
|
worker_path = self._resolve_worker_path()
|
||||||
|
worker_payload: dict = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for filename, content, media_type in files:
|
for filename, content, media_type in files:
|
||||||
@@ -73,17 +108,55 @@ class OcrService:
|
|||||||
|
|
||||||
temp_path = temp_root / f"{uuid4().hex}{suffix}"
|
temp_path = temp_root / f"{uuid4().hex}{suffix}"
|
||||||
temp_path.write_bytes(content)
|
temp_path.write_bytes(content)
|
||||||
input_paths.append(temp_path)
|
cleanup_paths.append(temp_path)
|
||||||
meta_by_path[str(temp_path)] = (normalized_name, resolved_media_type)
|
|
||||||
|
|
||||||
if input_paths:
|
if suffix == ".pdf":
|
||||||
|
try:
|
||||||
|
prepared_inputs.extend(
|
||||||
|
self._prepare_pdf_inputs(
|
||||||
|
pdf_path=temp_path,
|
||||||
|
filename=normalized_name,
|
||||||
|
media_type=resolved_media_type,
|
||||||
|
cleanup_paths=cleanup_paths,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
documents.append(
|
||||||
|
OcrRecognizeDocumentRead(
|
||||||
|
filename=normalized_name,
|
||||||
|
media_type=resolved_media_type,
|
||||||
|
warnings=[str(exc)],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
prepared_inputs.append(
|
||||||
|
PreparedOcrInput(
|
||||||
|
input_path=temp_path,
|
||||||
|
source_key=uuid4().hex,
|
||||||
|
filename=normalized_name,
|
||||||
|
media_type=resolved_media_type,
|
||||||
|
preview_kind="image" if resolved_media_type.startswith("image/") else "",
|
||||||
|
preview_data_url=(
|
||||||
|
self._build_preview_data_url(temp_path, media_type=resolved_media_type)
|
||||||
|
if resolved_media_type.startswith("image/")
|
||||||
|
else ""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if prepared_inputs:
|
||||||
worker_payload = self._invoke_worker(
|
worker_payload = self._invoke_worker(
|
||||||
python_bin=python_bin,
|
python_bin=python_bin,
|
||||||
worker_path=worker_path,
|
worker_path=worker_path,
|
||||||
input_paths=input_paths,
|
input_paths=[item.input_path for item in prepared_inputs],
|
||||||
|
)
|
||||||
|
documents.extend(
|
||||||
|
self._build_documents(
|
||||||
|
worker_documents=worker_payload.get("documents", []),
|
||||||
|
prepared_inputs=prepared_inputs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for item in worker_payload.get("documents", []):
|
|
||||||
documents.append(self._build_document(item, meta_by_path))
|
|
||||||
|
|
||||||
success_count = sum(
|
success_count = sum(
|
||||||
1
|
1
|
||||||
@@ -92,12 +165,12 @@ class OcrService:
|
|||||||
)
|
)
|
||||||
engine = (
|
engine = (
|
||||||
str(worker_payload.get("engine", "paddleocr_mobile"))
|
str(worker_payload.get("engine", "paddleocr_mobile"))
|
||||||
if input_paths
|
if prepared_inputs
|
||||||
else "paddleocr_mobile"
|
else "paddleocr_mobile"
|
||||||
)
|
)
|
||||||
model = (
|
model = (
|
||||||
str(worker_payload.get("model", "PP-OCRv5_mobile"))
|
str(worker_payload.get("model", "PP-OCRv5_mobile"))
|
||||||
if input_paths
|
if prepared_inputs
|
||||||
else "PP-OCRv5_mobile"
|
else "PP-OCRv5_mobile"
|
||||||
)
|
)
|
||||||
return OcrRecognizeBatchRead(
|
return OcrRecognizeBatchRead(
|
||||||
@@ -108,8 +181,7 @@ class OcrService:
|
|||||||
documents=documents,
|
documents=documents,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
for path in input_paths:
|
self._cleanup_temp_paths(cleanup_paths)
|
||||||
path.unlink(missing_ok=True)
|
|
||||||
|
|
||||||
def _resolve_python_bin(self) -> str:
|
def _resolve_python_bin(self) -> str:
|
||||||
candidates = []
|
candidates = []
|
||||||
@@ -182,40 +254,258 @@ class OcrService:
|
|||||||
return json.loads(normalized[len(WORKER_JSON_PREFIX) :])
|
return json.loads(normalized[len(WORKER_JSON_PREFIX) :])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
def _prepare_pdf_inputs(
|
||||||
def _build_document(
|
self,
|
||||||
payload: dict,
|
*,
|
||||||
meta_by_path: dict[str, tuple[str, str]],
|
pdf_path: Path,
|
||||||
) -> OcrRecognizeDocumentRead:
|
filename: str,
|
||||||
input_path = str(payload.get("input_path") or "")
|
media_type: str,
|
||||||
filename, media_type = meta_by_path.get(
|
cleanup_paths: list[Path],
|
||||||
input_path,
|
) -> list[PreparedOcrInput]:
|
||||||
(Path(input_path).name or "upload.bin", "application/octet-stream"),
|
output_dir = pdf_path.with_suffix("")
|
||||||
)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
lines = [
|
cleanup_paths.append(output_dir)
|
||||||
OcrRecognizeLineRead(
|
|
||||||
text=str(item.get("text", "")),
|
image_paths = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=output_dir)
|
||||||
score=float(item.get("score", 0.0) or 0.0),
|
if not image_paths:
|
||||||
box=[
|
raise RuntimeError("PDF 转图片后未生成可识别页面。")
|
||||||
[int(point[0]), int(point[1])]
|
|
||||||
for point in item.get("box", [])
|
preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png")
|
||||||
if isinstance(point, list) and len(point) == 2
|
source_key = uuid4().hex
|
||||||
],
|
descriptors: list[PreparedOcrInput] = []
|
||||||
page_index=int(item["page_index"]) if item.get("page_index") is not None else None,
|
for page_index, image_path in enumerate(image_paths):
|
||||||
|
descriptors.append(
|
||||||
|
PreparedOcrInput(
|
||||||
|
input_path=image_path,
|
||||||
|
source_key=source_key,
|
||||||
|
filename=filename,
|
||||||
|
media_type=media_type,
|
||||||
|
page_index=page_index,
|
||||||
|
preview_kind="image" if page_index == 0 else "",
|
||||||
|
preview_data_url=preview_data_url if page_index == 0 else "",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for item in payload.get("lines", [])
|
return descriptors
|
||||||
if isinstance(item, dict)
|
|
||||||
]
|
def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]:
|
||||||
return OcrRecognizeDocumentRead(
|
prefix = output_dir / "page"
|
||||||
filename=filename,
|
completed = subprocess.run(
|
||||||
media_type=media_type,
|
[
|
||||||
engine=str(payload.get("engine", "paddleocr_mobile")),
|
"pdftoppm",
|
||||||
model=str(payload.get("model", "PP-OCRv5_mobile")),
|
"-png",
|
||||||
text=str(payload.get("text", "")),
|
"-r",
|
||||||
summary=str(payload.get("summary", "")),
|
"160",
|
||||||
avg_score=float(payload.get("avg_score", 0.0) or 0.0),
|
str(pdf_path),
|
||||||
line_count=int(payload.get("line_count", len(lines)) or 0),
|
str(prefix),
|
||||||
page_count=int(payload.get("page_count", 1) or 1),
|
],
|
||||||
warnings=[str(item) for item in payload.get("warnings", [])],
|
capture_output=True,
|
||||||
lines=lines,
|
text=True,
|
||||||
|
timeout=self.settings.ocr_timeout_seconds,
|
||||||
|
check=False,
|
||||||
)
|
)
|
||||||
|
if completed.returncode != 0:
|
||||||
|
detail = (completed.stderr or completed.stdout or "").strip()
|
||||||
|
raise RuntimeError(f"PDF 转图片失败:{detail or 'pdftoppm 返回非 0 状态码。'}")
|
||||||
|
|
||||||
|
return sorted(output_dir.glob("page-*.png"), key=self._extract_pdf_page_sort_key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]:
|
||||||
|
suffix = path.stem.rsplit("-", 1)[-1]
|
||||||
|
try:
|
||||||
|
return int(suffix), path.name
|
||||||
|
except ValueError:
|
||||||
|
return 0, path.name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_preview_data_url(path: Path, *, media_type: str) -> str:
|
||||||
|
encoded = base64.b64encode(path.read_bytes()).decode("ascii")
|
||||||
|
return f"data:{media_type};base64,{encoded}"
|
||||||
|
|
||||||
|
def _build_documents(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
worker_documents: list[dict],
|
||||||
|
prepared_inputs: list[PreparedOcrInput],
|
||||||
|
) -> list[OcrRecognizeDocumentRead]:
|
||||||
|
descriptor_by_path = {str(item.input_path): item for item in prepared_inputs}
|
||||||
|
source_order: list[str] = []
|
||||||
|
seen_sources: set[str] = set()
|
||||||
|
for item in prepared_inputs:
|
||||||
|
if item.source_key in seen_sources:
|
||||||
|
continue
|
||||||
|
seen_sources.add(item.source_key)
|
||||||
|
source_order.append(item.source_key)
|
||||||
|
|
||||||
|
aggregated_by_source: dict[str, AggregatedOcrDocument] = {}
|
||||||
|
for payload in worker_documents:
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
continue
|
||||||
|
input_path = str(payload.get("input_path") or "")
|
||||||
|
descriptor = descriptor_by_path.get(input_path)
|
||||||
|
if descriptor is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
aggregated = aggregated_by_source.get(descriptor.source_key)
|
||||||
|
if aggregated is None:
|
||||||
|
aggregated = AggregatedOcrDocument(
|
||||||
|
filename=descriptor.filename,
|
||||||
|
media_type=descriptor.media_type,
|
||||||
|
source_key=descriptor.source_key,
|
||||||
|
engine=str(payload.get("engine", "paddleocr_mobile")),
|
||||||
|
model=str(payload.get("model", "PP-OCRv5_mobile")),
|
||||||
|
)
|
||||||
|
aggregated_by_source[descriptor.source_key] = aggregated
|
||||||
|
|
||||||
|
aggregated.page_count = max(
|
||||||
|
aggregated.page_count,
|
||||||
|
(descriptor.page_index + 1)
|
||||||
|
if descriptor.page_index is not None
|
||||||
|
else int(payload.get("page_count", 1) or 1),
|
||||||
|
)
|
||||||
|
if descriptor.preview_kind and not aggregated.preview_kind:
|
||||||
|
aggregated.preview_kind = descriptor.preview_kind
|
||||||
|
if descriptor.preview_data_url and not aggregated.preview_data_url:
|
||||||
|
aggregated.preview_data_url = descriptor.preview_data_url
|
||||||
|
|
||||||
|
page_summary = str(payload.get("summary", "") or "").strip()
|
||||||
|
if page_summary:
|
||||||
|
aggregated.summary_fragments.append(page_summary)
|
||||||
|
|
||||||
|
page_text = str(payload.get("text", "") or "").strip()
|
||||||
|
if page_text:
|
||||||
|
aggregated.text_fragments.append(page_text)
|
||||||
|
|
||||||
|
lines = self._build_lines(
|
||||||
|
payload.get("lines", []),
|
||||||
|
page_index_override=descriptor.page_index,
|
||||||
|
)
|
||||||
|
aggregated.lines.extend(lines)
|
||||||
|
aggregated.score_values.extend(line.score for line in lines if line.score > 0)
|
||||||
|
|
||||||
|
if not lines:
|
||||||
|
avg_score = float(payload.get("avg_score", 0.0) or 0.0)
|
||||||
|
if avg_score > 0:
|
||||||
|
aggregated.score_values.append(avg_score)
|
||||||
|
|
||||||
|
for warning in payload.get("warnings", []):
|
||||||
|
normalized_warning = str(warning or "").strip()
|
||||||
|
if normalized_warning and normalized_warning not in aggregated.warnings:
|
||||||
|
aggregated.warnings.append(normalized_warning)
|
||||||
|
|
||||||
|
documents: list[OcrRecognizeDocumentRead] = []
|
||||||
|
for source_key in source_order:
|
||||||
|
descriptors = [item for item in prepared_inputs if item.source_key == source_key]
|
||||||
|
if not descriptors:
|
||||||
|
continue
|
||||||
|
aggregated = aggregated_by_source.get(source_key)
|
||||||
|
if aggregated is None:
|
||||||
|
first_descriptor = descriptors[0]
|
||||||
|
documents.append(
|
||||||
|
OcrRecognizeDocumentRead(
|
||||||
|
filename=first_descriptor.filename,
|
||||||
|
media_type=first_descriptor.media_type,
|
||||||
|
page_count=max(1, len(descriptors)),
|
||||||
|
preview_kind=first_descriptor.preview_kind,
|
||||||
|
preview_data_url=first_descriptor.preview_data_url,
|
||||||
|
warnings=["OCR worker 未返回该文件的识别结果。"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
documents.append(self._finalize_document(aggregated))
|
||||||
|
|
||||||
|
return documents
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_lines(
|
||||||
|
items: list[dict],
|
||||||
|
*,
|
||||||
|
page_index_override: int | None = None,
|
||||||
|
) -> list[OcrRecognizeLineRead]:
|
||||||
|
lines: list[OcrRecognizeLineRead] = []
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
page_index = page_index_override
|
||||||
|
if page_index is None and item.get("page_index") is not None:
|
||||||
|
page_index = int(item["page_index"])
|
||||||
|
lines.append(
|
||||||
|
OcrRecognizeLineRead(
|
||||||
|
text=str(item.get("text", "")),
|
||||||
|
score=float(item.get("score", 0.0) or 0.0),
|
||||||
|
box=[
|
||||||
|
[int(point[0]), int(point[1])]
|
||||||
|
for point in item.get("box", [])
|
||||||
|
if isinstance(point, list) and len(point) == 2
|
||||||
|
],
|
||||||
|
page_index=page_index,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return lines
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _truncate_summary(parts: list[str]) -> str:
|
||||||
|
summary = ";".join([part for part in parts if part][:3])
|
||||||
|
if len(summary) > 180:
|
||||||
|
return f"{summary[:177]}..."
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead:
|
||||||
|
full_text = "\n".join(fragment for fragment in aggregated.text_fragments if fragment).strip()
|
||||||
|
summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments)
|
||||||
|
insight = self.document_intelligence_service.build_document_insight(
|
||||||
|
filename=aggregated.filename,
|
||||||
|
summary=summary,
|
||||||
|
text=full_text,
|
||||||
|
preview_data_url=aggregated.preview_data_url,
|
||||||
|
)
|
||||||
|
warnings = list(aggregated.warnings)
|
||||||
|
for warning in insight.warnings:
|
||||||
|
normalized_warning = str(warning or "").strip()
|
||||||
|
if normalized_warning and normalized_warning not in warnings:
|
||||||
|
warnings.append(normalized_warning)
|
||||||
|
return OcrRecognizeDocumentRead(
|
||||||
|
filename=aggregated.filename,
|
||||||
|
media_type=aggregated.media_type,
|
||||||
|
engine=aggregated.engine,
|
||||||
|
model=aggregated.model,
|
||||||
|
text=full_text,
|
||||||
|
summary=summary,
|
||||||
|
avg_score=(
|
||||||
|
sum(aggregated.score_values) / len(aggregated.score_values)
|
||||||
|
if aggregated.score_values
|
||||||
|
else 0.0
|
||||||
|
),
|
||||||
|
line_count=len(aggregated.lines),
|
||||||
|
page_count=max(1, aggregated.page_count),
|
||||||
|
document_type=insight.document_type,
|
||||||
|
document_type_label=insight.document_type_label,
|
||||||
|
scene_code=insight.scene_code,
|
||||||
|
scene_label=insight.scene_label,
|
||||||
|
classification_source=insight.classification_source,
|
||||||
|
classification_confidence=insight.classification_confidence,
|
||||||
|
classification_evidence=list(insight.evidence),
|
||||||
|
document_fields=[
|
||||||
|
OcrRecognizeFieldRead(
|
||||||
|
key=field.key,
|
||||||
|
label=field.label,
|
||||||
|
value=field.value,
|
||||||
|
)
|
||||||
|
for field in insight.fields
|
||||||
|
],
|
||||||
|
preview_kind=aggregated.preview_kind,
|
||||||
|
preview_data_url=aggregated.preview_data_url,
|
||||||
|
warnings=warnings,
|
||||||
|
lines=sorted(
|
||||||
|
aggregated.lines,
|
||||||
|
key=lambda item: item.page_index if item.page_index is not None else -1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cleanup_temp_paths(paths: list[Path]) -> None:
|
||||||
|
for path in reversed(paths):
|
||||||
|
if path.is_dir():
|
||||||
|
shutil.rmtree(path, ignore_errors=True)
|
||||||
|
continue
|
||||||
|
path.unlink(missing_ok=True)
|
||||||
|
|||||||
66
server/tests/test_document_intelligence.py
Normal file
66
server/tests/test_document_intelligence.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.services.document_intelligence import DocumentIntelligenceService, build_document_insight
|
||||||
|
from app.services.runtime_chat import RuntimeChatService
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_document_insight_prefers_transport_for_didi_text_with_hotel_noise() -> None:
|
||||||
|
insight = build_document_insight(
|
||||||
|
filename="didi-trip.png",
|
||||||
|
summary="滴滴出行行程单",
|
||||||
|
text="滴滴出行电子发票 订单号 12345678 上车点 深圳湾 下车点 后海 全季酒店 里程 12.4 公里 金额 48 元",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert insight.document_type == "taxi_receipt"
|
||||||
|
assert insight.document_type_label == "出租车/网约车票据"
|
||||||
|
assert insight.scene_code == "transport"
|
||||||
|
assert any(field.label == "金额" and field.value == "48元" for field in insight.fields)
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_intelligence_service_uses_vlm_result_when_preview_available(monkeypatch) -> None:
|
||||||
|
calls: list[tuple[str, ...]] = []
|
||||||
|
|
||||||
|
def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2):
|
||||||
|
calls.append(slot_priority)
|
||||||
|
if slot_priority == ("vlm",):
|
||||||
|
assert isinstance(messages[1]["content"], list)
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"document_type": "taxi_receipt",
|
||||||
|
"scene_code": "transport",
|
||||||
|
"scene_label": "交通票据",
|
||||||
|
"expense_type": "transport",
|
||||||
|
"confidence": 0.91,
|
||||||
|
"evidence": ["图片主体为滴滴行程单,OCR 中出现订单号、上车、下车等字段"],
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(RuntimeChatService, "complete", fake_complete)
|
||||||
|
|
||||||
|
engine = create_engine(
|
||||||
|
"sqlite+pysqlite:///:memory:",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
session = sessionmaker(bind=engine, autoflush=False, autocommit=False)()
|
||||||
|
try:
|
||||||
|
insight = DocumentIntelligenceService(session).build_document_insight(
|
||||||
|
filename="mixed-noise.png",
|
||||||
|
summary="OCR 混入酒店名称",
|
||||||
|
text="全季酒店 滴滴出行 订单号 12345678 上车 下车 金额 52 元",
|
||||||
|
preview_data_url="data:image/png;base64,ZmFrZQ==",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
assert insight.document_type == "taxi_receipt"
|
||||||
|
assert insight.classification_source == "llm_vision"
|
||||||
|
assert calls[0] == ("vlm",)
|
||||||
@@ -10,7 +10,7 @@ from sqlalchemy.pool import StaticPool
|
|||||||
from app.api.deps import get_db
|
from app.api.deps import get_db
|
||||||
from app.db.base import Base
|
from app.db.base import Base
|
||||||
from app.main import create_app
|
from app.main import create_app
|
||||||
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
|
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
|
||||||
from app.services.ocr import OcrService
|
from app.services.ocr import OcrService
|
||||||
|
|
||||||
|
|
||||||
@@ -50,14 +50,23 @@ def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None:
|
|||||||
OcrRecognizeDocumentRead(
|
OcrRecognizeDocumentRead(
|
||||||
filename="invoice.png",
|
filename="invoice.png",
|
||||||
media_type="image/png",
|
media_type="image/png",
|
||||||
text="发票金额 100 元",
|
text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
|
||||||
summary="发票金额 100 元",
|
summary="增值税电子发票,金额 100 元。",
|
||||||
avg_score=0.98,
|
avg_score=0.98,
|
||||||
line_count=1,
|
line_count=1,
|
||||||
page_count=1,
|
page_count=1,
|
||||||
|
document_type="vat_invoice",
|
||||||
|
document_type_label="增值税发票",
|
||||||
|
scene_code="other",
|
||||||
|
scene_label="通用发票",
|
||||||
|
document_fields=[
|
||||||
|
OcrRecognizeFieldRead(key="amount", label="金额", value="100元"),
|
||||||
|
OcrRecognizeFieldRead(key="date", label="日期", value="2026-05-13"),
|
||||||
|
OcrRecognizeFieldRead(key="invoice_number", label="票据号码", value="12345678"),
|
||||||
|
],
|
||||||
lines=[
|
lines=[
|
||||||
OcrRecognizeLineRead(
|
OcrRecognizeLineRead(
|
||||||
text="发票金额 100 元",
|
text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
|
||||||
score=0.98,
|
score=0.98,
|
||||||
box=[[1, 2], [10, 2], [10, 8], [1, 8]],
|
box=[[1, 2], [10, 2], [10, 8], [1, 8]],
|
||||||
page_index=0,
|
page_index=0,
|
||||||
@@ -81,4 +90,7 @@ def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None:
|
|||||||
assert payload["engine"] == "paddleocr_mobile"
|
assert payload["engine"] == "paddleocr_mobile"
|
||||||
assert payload["success_count"] == 1
|
assert payload["success_count"] == 1
|
||||||
assert payload["documents"][0]["filename"] == "invoice.png"
|
assert payload["documents"][0]["filename"] == "invoice.png"
|
||||||
assert payload["documents"][0]["summary"] == "发票金额 100 元"
|
assert payload["documents"][0]["summary"] == "增值税电子发票,金额 100 元。"
|
||||||
|
assert payload["documents"][0]["document_type"] == "vat_invoice"
|
||||||
|
assert payload["documents"][0]["document_type_label"] == "增值税发票"
|
||||||
|
assert payload["documents"][0]["document_fields"][0]["label"] == "金额"
|
||||||
|
|||||||
@@ -26,15 +26,15 @@ for index, arg in enumerate(sys.argv):
|
|||||||
"input_path": input_path,
|
"input_path": input_path,
|
||||||
"engine": "paddleocr_mobile",
|
"engine": "paddleocr_mobile",
|
||||||
"model": "PP-OCRv5_mobile",
|
"model": "PP-OCRv5_mobile",
|
||||||
"text": "发票金额 100 元",
|
"text": "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
|
||||||
"summary": "发票金额 100 元",
|
"summary": "增值税电子发票,金额 100 元。",
|
||||||
"avg_score": 0.98,
|
"avg_score": 0.98,
|
||||||
"line_count": 1,
|
"line_count": 1,
|
||||||
"page_count": 1,
|
"page_count": 1,
|
||||||
"warnings": [],
|
"warnings": [],
|
||||||
"lines": [
|
"lines": [
|
||||||
{
|
{
|
||||||
"text": "发票金额 100 元",
|
"text": "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
|
||||||
"score": 0.98,
|
"score": 0.98,
|
||||||
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
|
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
|
||||||
"page_index": 0,
|
"page_index": 0,
|
||||||
@@ -74,10 +74,106 @@ print("__OCR_JSON__=" + json.dumps(payload, ensure_ascii=False))
|
|||||||
assert len(result.documents) == 2
|
assert len(result.documents) == 2
|
||||||
|
|
||||||
recognized = next(item for item in result.documents if item.filename == "invoice.png")
|
recognized = next(item for item in result.documents if item.filename == "invoice.png")
|
||||||
assert recognized.summary == "发票金额 100 元"
|
assert recognized.summary == "增值税电子发票,金额 100 元。"
|
||||||
assert recognized.line_count == 1
|
assert recognized.line_count == 1
|
||||||
assert recognized.lines[0].text == "发票金额 100 元"
|
assert recognized.document_type == "vat_invoice"
|
||||||
|
assert recognized.document_type_label == "增值税发票"
|
||||||
|
assert any(field.label == "金额" and field.value == "100元" for field in recognized.document_fields)
|
||||||
|
assert any(field.label == "票据号码" and field.value == "12345678" for field in recognized.document_fields)
|
||||||
|
assert any(field.label == "日期" and field.value == "2026-05-13" for field in recognized.document_fields)
|
||||||
|
assert recognized.lines[0].text == "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13"
|
||||||
|
|
||||||
skipped = next(item for item in result.documents if item.filename == "notes.txt")
|
skipped = next(item for item in result.documents if item.filename == "notes.txt")
|
||||||
assert skipped.line_count == 0
|
assert skipped.line_count == 0
|
||||||
assert skipped.warnings == ["当前仅支持图片和 PDF 文件进行 OCR。"]
|
assert skipped.warnings == ["当前仅支持图片和 PDF 文件进行 OCR。"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ocr_service_converts_pdf_to_images_and_returns_image_preview(
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
def fake_convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]:
|
||||||
|
first = output_dir / "page-1.png"
|
||||||
|
second = output_dir / "page-2.png"
|
||||||
|
first.write_bytes(b"fake-page-1")
|
||||||
|
second.write_bytes(b"fake-page-2")
|
||||||
|
return [first, second]
|
||||||
|
|
||||||
|
def fake_invoke_worker(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
python_bin: str,
|
||||||
|
worker_path: str,
|
||||||
|
input_paths: list[Path],
|
||||||
|
) -> dict:
|
||||||
|
assert [path.name for path in input_paths] == ["page-1.png", "page-2.png"]
|
||||||
|
return {
|
||||||
|
"engine": "paddleocr_mobile",
|
||||||
|
"model": "PP-OCRv5_mobile",
|
||||||
|
"documents": [
|
||||||
|
{
|
||||||
|
"input_path": str(input_paths[0]),
|
||||||
|
"engine": "paddleocr_mobile",
|
||||||
|
"model": "PP-OCRv5_mobile",
|
||||||
|
"text": "高铁票 深圳北-广州南 车次 G1234 2026-05-13 金额 188 元",
|
||||||
|
"summary": "高铁票第一页",
|
||||||
|
"avg_score": 0.97,
|
||||||
|
"line_count": 1,
|
||||||
|
"page_count": 1,
|
||||||
|
"warnings": [],
|
||||||
|
"lines": [
|
||||||
|
{
|
||||||
|
"text": "高铁票 深圳北-广州南 车次 G1234 2026-05-13 金额 188 元",
|
||||||
|
"score": 0.97,
|
||||||
|
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"input_path": str(input_paths[1]),
|
||||||
|
"engine": "paddleocr_mobile",
|
||||||
|
"model": "PP-OCRv5_mobile",
|
||||||
|
"text": "乘车人 张三",
|
||||||
|
"summary": "高铁票第二页",
|
||||||
|
"avg_score": 0.94,
|
||||||
|
"line_count": 1,
|
||||||
|
"page_count": 1,
|
||||||
|
"warnings": [],
|
||||||
|
"lines": [
|
||||||
|
{
|
||||||
|
"text": "乘车人 张三",
|
||||||
|
"score": 0.94,
|
||||||
|
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
|
||||||
|
monkeypatch.setattr(OcrService, "_resolve_python_bin", lambda self: "python")
|
||||||
|
monkeypatch.setattr(OcrService, "_resolve_worker_path", lambda self: "worker.py")
|
||||||
|
monkeypatch.setattr(OcrService, "_convert_pdf_to_images", fake_convert_pdf_to_images)
|
||||||
|
monkeypatch.setattr(OcrService, "_invoke_worker", fake_invoke_worker)
|
||||||
|
get_settings.cache_clear()
|
||||||
|
try:
|
||||||
|
result = OcrService().recognize_files(
|
||||||
|
[
|
||||||
|
("train-ticket.pdf", b"%PDF-1.4 fake", "application/pdf"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
assert result.success_count == 1
|
||||||
|
assert len(result.documents) == 1
|
||||||
|
recognized = result.documents[0]
|
||||||
|
assert recognized.filename == "train-ticket.pdf"
|
||||||
|
assert recognized.page_count == 2
|
||||||
|
assert recognized.preview_kind == "image"
|
||||||
|
assert recognized.preview_data_url.startswith("data:image/png;base64,")
|
||||||
|
assert recognized.document_type == "train_ticket"
|
||||||
|
assert any(field.label == "金额" and field.value == "188元" for field in recognized.document_fields)
|
||||||
|
assert any(field.label == "车次/航班" and field.value == "G1234" for field in recognized.document_fields)
|
||||||
|
assert recognized.lines[0].page_index == 0
|
||||||
|
assert recognized.lines[1].page_index == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user