feat(server): 新增文档智能识别服务,扩展OCR接口支持 Azure Document Intelligence

This commit is contained in:
caoxiaozhu
2026-05-14 09:32:15 +00:00
parent 8adeefe4a9
commit 8b39f48dec
7 changed files with 1128 additions and 61 deletions

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from sqlalchemy.orm import Session
from app.api.deps import CurrentUserContext, get_current_user
from app.api.deps import CurrentUserContext, get_current_user, get_db
from app.schemas.common import ErrorResponse
from app.schemas.ocr import OcrRecognizeBatchRead
from app.services.ocr import OcrService
@@ -35,6 +36,7 @@ router = APIRouter(prefix="/ocr")
async def recognize_ocr_documents(
files: Annotated[list[UploadFile], File(description="待识别的票据图片或 PDF。")],
_: Annotated[CurrentUserContext, Depends(get_current_user)],
db: Annotated[Session, Depends(get_db)],
) -> OcrRecognizeBatchRead:
try:
payload = []
@@ -46,7 +48,7 @@ async def recognize_ocr_documents(
upload.content_type,
)
)
return OcrService().recognize_files(payload)
return OcrService(db).recognize_files(payload)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
except RuntimeError as exc:

View File

@@ -10,6 +10,12 @@ class OcrRecognizeLineRead(BaseModel):
page_index: int | None = Field(default=None, description="页码,从 0 开始。")
class OcrRecognizeFieldRead(BaseModel):
key: str = Field(description="结构化字段键。")
label: str = Field(description="结构化字段展示名。")
value: str = Field(default="", description="结构化字段值。")
class OcrRecognizeDocumentRead(BaseModel):
filename: str = Field(description="原始文件名。")
media_type: str = Field(description="文件媒体类型。")
@@ -20,6 +26,19 @@ class OcrRecognizeDocumentRead(BaseModel):
avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="平均识别置信度。")
line_count: int = Field(default=0, ge=0, description="文本行数。")
page_count: int = Field(default=1, ge=0, description="识别页数。")
document_type: str = Field(default="other", description="识别出的票据类型编码。")
document_type_label: str = Field(default="其他单据", description="识别出的票据类型名称。")
scene_code: str = Field(default="other", description="识别出的票据场景编码。")
scene_label: str = Field(default="其他票据", description="识别出的票据场景名称。")
classification_source: str = Field(default="rule", description="票据类型判断来源,例如 rule / llm_text / llm_vision。")
classification_confidence: float = Field(default=0.0, ge=0.0, le=1.0, description="票据类型判断置信度。")
classification_evidence: list[str] = Field(default_factory=list, description="票据类型判断依据摘要。")
document_fields: list[OcrRecognizeFieldRead] = Field(
default_factory=list,
description="识别出的结构化票据信息。",
)
preview_kind: str = Field(default="", description="预览类型PDF 转图后通常为 image。")
preview_data_url: str = Field(default="", description="用于前端展示的图片预览 data URL。")
warnings: list[str] = Field(default_factory=list, description="该文件的识别提示或警告。")
lines: list[OcrRecognizeLineRead] = Field(default_factory=list, description="逐行识别结果。")

View File

@@ -0,0 +1,582 @@
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from app.services.runtime_chat import RuntimeChatService
@dataclass(frozen=True, slots=True)
class DocumentField:
key: str
label: str
value: str
@dataclass(frozen=True, slots=True)
class DocumentInsight:
document_type: str
document_type_label: str
scene_code: str
scene_label: str
expense_type: str
fields: tuple[DocumentField, ...] = ()
classification_source: str = "rule"
classification_confidence: float = 0.0
evidence: tuple[str, ...] = ()
warnings: tuple[str, ...] = ()
@dataclass(frozen=True, slots=True)
class DocumentRule:
document_type: str
document_type_label: str
scene_code: str
scene_label: str
expense_type: str
keywords: tuple[str, ...]
score_bias: float = 0.0
@dataclass(frozen=True, slots=True)
class RuleMatch:
rule: DocumentRule | None
confidence: float
evidence: tuple[str, ...]
score: float
class LlmDocumentClassification(BaseModel):
document_type: str = Field(default="other")
scene_code: str = Field(default="other")
scene_label: str = Field(default="其他票据")
expense_type: str = Field(default="other")
confidence: float = Field(default=0.0, ge=0.0, le=1.0)
evidence: list[str] = Field(default_factory=list)
DEFAULT_RULE = DocumentRule(
document_type="other",
document_type_label="其他单据",
scene_code="other",
scene_label="其他票据",
expense_type="other",
keywords=(),
score_bias=0.0,
)
DOCUMENT_RULES: tuple[DocumentRule, ...] = (
DocumentRule(
document_type="flight_itinerary",
document_type_label="机票/航班行程单",
scene_code="travel",
scene_label="差旅票据",
expense_type="travel",
keywords=("电子行程单", "航班号", "航班", "机票", "登机", "航空", "客票"),
score_bias=0.34,
),
DocumentRule(
document_type="train_ticket",
document_type_label="火车/高铁票",
scene_code="travel",
scene_label="差旅票据",
expense_type="travel",
keywords=("高铁", "火车", "动车", "铁路", "车次", "检票", "二等座", "一等座"),
score_bias=0.32,
),
DocumentRule(
document_type="hotel_invoice",
document_type_label="酒店住宿票据",
scene_code="hotel",
scene_label="住宿票据",
expense_type="hotel",
keywords=("住宿", "房费", "客房", "入住", "离店", "酒店", "宾馆", "间夜"),
score_bias=0.16,
),
DocumentRule(
document_type="taxi_receipt",
document_type_label="出租车/网约车票据",
scene_code="transport",
scene_label="交通票据",
expense_type="transport",
keywords=("滴滴出行", "滴滴", "网约车", "出租车", "打车", "快车", "专车", "订单号", "上车", "下车", "起点", "终点", "里程", "司机"),
score_bias=0.38,
),
DocumentRule(
document_type="parking_toll_receipt",
document_type_label="停车/通行费票据",
scene_code="transport",
scene_label="交通票据",
expense_type="transport",
keywords=("停车费", "通行费", "过路费", "收费站", "停车场", "停车"),
score_bias=0.28,
),
DocumentRule(
document_type="meal_receipt",
document_type_label="餐饮票据",
scene_code="meal",
scene_label="餐饮票据",
expense_type="meal",
keywords=("餐饮", "餐费", "用餐", "饭店", "酒楼", "餐厅", "食品", "外卖", "咖啡"),
score_bias=0.14,
),
DocumentRule(
document_type="office_invoice",
document_type_label="办公用品票据",
scene_code="office",
scene_label="办公用品票据",
expense_type="office",
keywords=("办公用品", "文具", "耗材", "打印纸", "墨盒", "硒鼓", "键盘", "鼠标"),
score_bias=0.14,
),
DocumentRule(
document_type="meeting_invoice",
document_type_label="会议/会务票据",
scene_code="meeting",
scene_label="会务票据",
expense_type="meeting",
keywords=("会议", "会务", "会展", "论坛", "会议室", "会场"),
score_bias=0.12,
),
DocumentRule(
document_type="training_invoice",
document_type_label="培训票据",
scene_code="training",
scene_label="培训票据",
expense_type="training",
keywords=("培训", "课程", "讲师", "教材", "学费", "认证"),
score_bias=0.12,
),
DocumentRule(
document_type="vat_invoice",
document_type_label="增值税发票",
scene_code="other",
scene_label="通用发票",
expense_type="other",
keywords=("发票代码", "发票号码", "价税合计", "增值税", "电子发票"),
score_bias=-0.08,
),
DocumentRule(
document_type="receipt",
document_type_label="一般收据/凭证",
scene_code="other",
scene_label="其他票据",
expense_type="other",
keywords=("收据", "凭证", "票据"),
score_bias=-0.18,
),
)
DOCUMENT_TYPE_RULE_MAP = {rule.document_type: rule for rule in DOCUMENT_RULES}
SUPPORTED_DOCUMENT_TYPES = tuple(DOCUMENT_TYPE_RULE_MAP.keys()) + ("other",)
AMOUNT_PATTERNS = (
re.compile(r"(?:价税合计|合计|金额|总额|票价|支付金额|实付金额|实收金额)[:\s¥¥]*([0-9]+(?:[.,][0-9]{1,2})?)"),
re.compile(r"([0-9]+(?:[.,][0-9]{1,2})?)\s*元"),
)
DATE_PATTERN = re.compile(r"((?:20\d{2}|19\d{2})[-/年.](?:1[0-2]|0?[1-9])[-/月.](?:3[01]|[12]\d|0?[1-9])日?)")
INVOICE_NUMBER_PATTERN = re.compile(r"(?:发票号码|票号|单号|订单号)[:\s]*([A-Za-z0-9-]{6,24})")
INVOICE_CODE_PATTERN = re.compile(r"(?:发票代码)[:\s]*([A-Za-z0-9-]{6,24})")
TRIP_NO_PATTERN = re.compile(r"(?:车次|航班(?:号)?)[:\s]*([A-Za-z0-9]{2,12})")
ROUTE_PATTERN = re.compile(r"([\u4e00-\u9fa5]{2,12})\s*(?:至|→|->|-)\s*([\u4e00-\u9fa5]{2,12})")
MERCHANT_PATTERNS = (
re.compile(r"(?:销售方(?:名称)?|商户(?:名称)?|开票方(?:名称)?|收款方(?:名称)?)[:\s]*([A-Za-z0-9\u4e00-\u9fa5()·&\\-]{2,40})"),
re.compile(r"([A-Za-z0-9\u4e00-\u9fa5()·&\\-]{2,40}(?:酒店|宾馆|饭店|酒楼|餐厅|航空|铁路|滴滴出行|停车场|服务区))"),
)
class DocumentIntelligenceService:
def __init__(self, db: Session | None = None) -> None:
self.runtime_chat_service = RuntimeChatService(db) if db is not None else None
def build_document_insight(
self,
*,
filename: str = "",
summary: str = "",
text: str = "",
preview_data_url: str = "",
) -> DocumentInsight:
raw_text = " ".join(
[str(filename or "").strip(), str(summary or "").strip(), str(text or "").strip()]
).strip()
compact = re.sub(r"\s+", "", raw_text).lower()
rule_match = _match_document_rule(compact)
base_rule = rule_match.rule or DEFAULT_RULE
fields = tuple(_extract_document_fields(raw_text))
rule_insight = DocumentInsight(
document_type=base_rule.document_type,
document_type_label=base_rule.document_type_label,
scene_code=base_rule.scene_code,
scene_label=base_rule.scene_label,
expense_type=base_rule.expense_type,
fields=fields,
classification_source="rule",
classification_confidence=rule_match.confidence,
evidence=rule_match.evidence,
)
llm_result = self._classify_with_model(
filename=str(filename or "").strip(),
summary=str(summary or "").strip(),
text=str(text or "").strip(),
preview_data_url=str(preview_data_url or "").strip(),
rule_insight=rule_insight,
fields=fields,
)
if llm_result is None:
return rule_insight
return self._merge_rule_and_model(
rule_insight=rule_insight,
llm_result=llm_result,
fields=fields,
has_preview=bool(preview_data_url),
)
def _classify_with_model(
self,
*,
filename: str,
summary: str,
text: str,
preview_data_url: str,
rule_insight: DocumentInsight,
fields: tuple[DocumentField, ...],
) -> tuple[str, LlmDocumentClassification] | None:
if self.runtime_chat_service is None:
return None
trimmed_text = text.strip()
if not trimmed_text and not summary.strip():
return None
facts = {
"filename": filename,
"summary": summary[:300],
"ocr_text_excerpt": trimmed_text[:2000],
"rule_candidate": {
"document_type": rule_insight.document_type,
"document_type_label": rule_insight.document_type_label,
"scene_code": rule_insight.scene_code,
"scene_label": rule_insight.scene_label,
"expense_type": rule_insight.expense_type,
"confidence": round(rule_insight.classification_confidence, 2),
"evidence": list(rule_insight.evidence),
},
"extracted_fields": [
{"key": field.key, "label": field.label, "value": field.value}
for field in fields
],
"allowed_document_types": list(SUPPORTED_DOCUMENT_TYPES),
}
system_prompt = (
"你是企业报销票据识别复核器。"
"你的任务不是 OCR而是在已有 OCR 文本和票据预览基础上判断票据类型。"
"只输出 JSON 对象,不要输出 Markdown、解释或代码块。"
"document_type 只能是:"
f"{', '.join(SUPPORTED_DOCUMENT_TYPES)}"
"如果证据不足,返回 other。"
"严禁编造 OCR 中不存在的商户、酒店、航司、路线或金额。"
"如果 OCR 出现冲突碎片,应优先依据票据主体信息,而不是单个噪声词。"
"例如滴滴行程单/网约车发票,即使 OCR 混入酒店名称,也不能直接判成酒店票据。"
"输出字段document_type, scene_code, scene_label, expense_type, confidence, evidence。"
)
user_prompt = (
"请根据以下票据事实给出最终分类 JSON\n"
f"{json.dumps(facts, ensure_ascii=False, indent=2)}\n\n"
"示例输出:\n"
"{\n"
' "document_type": "taxi_receipt",\n'
' "scene_code": "transport",\n'
' "scene_label": "交通票据",\n'
' "expense_type": "transport",\n'
' "confidence": 0.86,\n'
' "evidence": ["OCR 中出现 滴滴出行、订单号、上车/下车 等交通特征"]\n'
"}"
)
if preview_data_url:
response_text = self.runtime_chat_service.complete(
[
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{"type": "image_url", "image_url": {"url": preview_data_url}},
],
},
],
slot_priority=("vlm",),
max_tokens=320,
temperature=0.0,
)
parsed = self._parse_llm_payload(response_text)
if parsed is not None:
return "llm_vision", parsed
response_text = self.runtime_chat_service.complete(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
slot_priority=("main", "backup"),
max_tokens=320,
temperature=0.0,
)
parsed = self._parse_llm_payload(response_text)
if parsed is not None:
return "llm_text", parsed
return None
@staticmethod
def _parse_llm_payload(response_text: str | None) -> LlmDocumentClassification | None:
payload_json = _extract_json_payload(response_text)
if payload_json is None:
return None
try:
parsed = LlmDocumentClassification.model_validate(payload_json)
except ValidationError:
return None
normalized_type = str(parsed.document_type or "other").strip().lower() or "other"
if normalized_type not in SUPPORTED_DOCUMENT_TYPES:
normalized_type = "other"
base_rule = DOCUMENT_TYPE_RULE_MAP.get(normalized_type, DEFAULT_RULE)
evidence = [
str(item or "").strip()
for item in parsed.evidence
if str(item or "").strip()
][:4]
return LlmDocumentClassification(
document_type=normalized_type,
scene_code=str(parsed.scene_code or base_rule.scene_code).strip() or base_rule.scene_code,
scene_label=str(parsed.scene_label or base_rule.scene_label).strip() or base_rule.scene_label,
expense_type=str(parsed.expense_type or base_rule.expense_type).strip() or base_rule.expense_type,
confidence=float(parsed.confidence),
evidence=evidence,
)
@staticmethod
def _merge_rule_and_model(
*,
rule_insight: DocumentInsight,
llm_result: tuple[str, LlmDocumentClassification],
fields: tuple[DocumentField, ...],
has_preview: bool,
) -> DocumentInsight:
source, parsed = llm_result
if parsed.confidence < 0.55:
return rule_insight
should_override = False
if parsed.document_type == rule_insight.document_type:
should_override = True
elif rule_insight.document_type == "other" and parsed.document_type != "other":
should_override = True
elif parsed.document_type != "other":
threshold = 0.60 if has_preview else max(0.76, rule_insight.classification_confidence + 0.12)
should_override = parsed.confidence >= threshold
if not should_override:
return rule_insight
rule = DOCUMENT_TYPE_RULE_MAP.get(parsed.document_type, DEFAULT_RULE)
warnings = list(rule_insight.warnings)
if parsed.document_type != rule_insight.document_type:
warnings.append("票据类型已结合大模型复核结果修正,建议人工再核对原图。")
return DocumentInsight(
document_type=rule.document_type,
document_type_label=rule.document_type_label,
scene_code=rule.scene_code if parsed.scene_code == "other" else parsed.scene_code,
scene_label=rule.scene_label if parsed.scene_label == "其他票据" else parsed.scene_label,
expense_type=rule.expense_type if parsed.expense_type == "other" else parsed.expense_type,
fields=fields,
classification_source=source,
classification_confidence=max(parsed.confidence, rule_insight.classification_confidence),
evidence=tuple(parsed.evidence or rule_insight.evidence),
warnings=tuple(warnings),
)
def build_document_insight(
*,
filename: str = "",
summary: str = "",
text: str = "",
preview_data_url: str = "",
) -> DocumentInsight:
return DocumentIntelligenceService().build_document_insight(
filename=filename,
summary=summary,
text=text,
preview_data_url=preview_data_url,
)
def _match_document_rule(compact_text: str) -> RuleMatch:
best_rule = DEFAULT_RULE
best_evidence: tuple[str, ...] = ()
best_score = 0.0
for rule in DOCUMENT_RULES:
matched = tuple(keyword for keyword in rule.keywords if keyword.lower() in compact_text)
if not matched:
continue
score = float(rule.score_bias) + len(matched) * 0.92 + sum(min(len(keyword), 6) * 0.08 for keyword in matched)
if score > best_score:
best_rule = rule
best_evidence = matched
best_score = score
if best_score <= 0:
return RuleMatch(rule=None, confidence=0.0, evidence=(), score=0.0)
confidence = min(0.94, 0.30 + min(best_score, 4.8) * 0.12)
return RuleMatch(
rule=best_rule,
confidence=round(confidence, 2),
evidence=best_evidence[:4],
score=best_score,
)
def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
if not response_text:
return None
cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE).strip()
if not cleaned:
return None
fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL)
candidates = [fenced_match.group(1)] if fenced_match else []
candidates.append(cleaned)
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
candidates.append(cleaned[start : end + 1])
for candidate in candidates:
try:
parsed = json.loads(candidate)
except json.JSONDecodeError:
continue
if isinstance(parsed, dict):
return parsed
return None
def _extract_document_fields(text: str) -> list[DocumentField]:
fields: list[DocumentField] = []
amount = _extract_amount(text)
if amount:
fields.append(DocumentField(key="amount", label="金额", value=amount))
date_value = _extract_date(text)
if date_value:
fields.append(DocumentField(key="date", label="日期", value=date_value))
merchant = _extract_merchant(text)
if merchant:
fields.append(DocumentField(key="merchant_name", label="商户", value=merchant))
invoice_number = _extract_pattern(INVOICE_NUMBER_PATTERN, text)
if invoice_number:
fields.append(DocumentField(key="invoice_number", label="票据号码", value=invoice_number))
invoice_code = _extract_pattern(INVOICE_CODE_PATTERN, text)
if invoice_code:
fields.append(DocumentField(key="invoice_code", label="发票代码", value=invoice_code))
trip_no = _extract_pattern(TRIP_NO_PATTERN, text)
if trip_no:
fields.append(DocumentField(key="trip_no", label="车次/航班", value=trip_no))
route = _extract_route(text)
if route:
fields.append(DocumentField(key="route", label="行程", value=route))
return fields
def _extract_amount(text: str) -> str:
best_value: Decimal | None = None
for pattern in AMOUNT_PATTERNS:
for match in pattern.finditer(text):
raw_value = str(match.group(1) or "").replace(",", ".").strip()
try:
candidate = Decimal(raw_value)
except InvalidOperation:
continue
if candidate <= Decimal("0.00"):
continue
if best_value is None or candidate > best_value:
best_value = candidate
if best_value is not None:
break
if best_value is None:
return ""
normalized = best_value.quantize(Decimal("0.01"))
text_value = format(normalized, "f").rstrip("0").rstrip(".")
return f"{text_value}"
def _extract_date(text: str) -> str:
match = DATE_PATTERN.search(text)
if not match:
return ""
raw_value = str(match.group(1) or "").strip()
normalized = raw_value.replace("", "-").replace("", "-").replace("", "")
normalized = normalized.replace("/", "-").replace(".", "-")
parts = [part for part in normalized.split("-") if part]
if len(parts) != 3:
return raw_value
year, month, day = parts
return f"{year.zfill(4)}-{month.zfill(2)}-{day.zfill(2)}"
def _extract_merchant(text: str) -> str:
for pattern in MERCHANT_PATTERNS:
match = pattern.search(text)
if not match:
continue
value = _clean_field_value(match.group(1))
if value:
return value
return ""
def _extract_route(text: str) -> str:
match = ROUTE_PATTERN.search(text)
if not match:
return ""
start = _clean_field_value(match.group(1))
end = _clean_field_value(match.group(2))
if not start or not end or start == end:
return ""
return f"{start}-{end}"
def _extract_pattern(pattern: re.Pattern[str], text: str) -> str:
match = pattern.search(text)
if not match:
return ""
return _clean_field_value(match.group(1))
def _clean_field_value(value: str) -> str:
return str(value or "").strip().strip(":,。.;")

View File

@@ -1,21 +1,55 @@
from __future__ import annotations
import base64
import json
import shutil
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from uuid import uuid4
from sqlalchemy.orm import Session
from app.core.config import SERVER_DIR, get_settings
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
from app.services.document_intelligence import DocumentIntelligenceService
WORKER_JSON_PREFIX = "__OCR_JSON__="
SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"}
@dataclass(slots=True)
class PreparedOcrInput:
input_path: Path
source_key: str
filename: str
media_type: str
page_index: int | None = None
preview_kind: str = ""
preview_data_url: str = ""
@dataclass(slots=True)
class AggregatedOcrDocument:
filename: str
media_type: str
source_key: str
engine: str = "paddleocr_mobile"
model: str = "PP-OCRv5_mobile"
summary_fragments: list[str] = field(default_factory=list)
text_fragments: list[str] = field(default_factory=list)
score_values: list[float] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
lines: list[OcrRecognizeLineRead] = field(default_factory=list)
page_count: int = 0
preview_kind: str = ""
preview_data_url: str = ""
class OcrService:
def __init__(self) -> None:
def __init__(self, db: Session | None = None) -> None:
self.settings = get_settings()
self.document_intelligence_service = DocumentIntelligenceService(db)
def recognize_files(
self,
@@ -28,10 +62,11 @@ class OcrService:
temp_root.mkdir(parents=True, exist_ok=True)
documents: list[OcrRecognizeDocumentRead] = []
input_paths: list[Path] = []
meta_by_path: dict[str, tuple[str, str]] = {}
prepared_inputs: list[PreparedOcrInput] = []
cleanup_paths: list[Path] = []
python_bin = self._resolve_python_bin()
worker_path = self._resolve_worker_path()
worker_payload: dict = {}
try:
for filename, content, media_type in files:
@@ -73,17 +108,55 @@ class OcrService:
temp_path = temp_root / f"{uuid4().hex}{suffix}"
temp_path.write_bytes(content)
input_paths.append(temp_path)
meta_by_path[str(temp_path)] = (normalized_name, resolved_media_type)
cleanup_paths.append(temp_path)
if input_paths:
if suffix == ".pdf":
try:
prepared_inputs.extend(
self._prepare_pdf_inputs(
pdf_path=temp_path,
filename=normalized_name,
media_type=resolved_media_type,
cleanup_paths=cleanup_paths,
)
)
except RuntimeError as exc:
documents.append(
OcrRecognizeDocumentRead(
filename=normalized_name,
media_type=resolved_media_type,
warnings=[str(exc)],
)
)
continue
prepared_inputs.append(
PreparedOcrInput(
input_path=temp_path,
source_key=uuid4().hex,
filename=normalized_name,
media_type=resolved_media_type,
preview_kind="image" if resolved_media_type.startswith("image/") else "",
preview_data_url=(
self._build_preview_data_url(temp_path, media_type=resolved_media_type)
if resolved_media_type.startswith("image/")
else ""
),
)
)
if prepared_inputs:
worker_payload = self._invoke_worker(
python_bin=python_bin,
worker_path=worker_path,
input_paths=input_paths,
input_paths=[item.input_path for item in prepared_inputs],
)
documents.extend(
self._build_documents(
worker_documents=worker_payload.get("documents", []),
prepared_inputs=prepared_inputs,
)
)
for item in worker_payload.get("documents", []):
documents.append(self._build_document(item, meta_by_path))
success_count = sum(
1
@@ -92,12 +165,12 @@ class OcrService:
)
engine = (
str(worker_payload.get("engine", "paddleocr_mobile"))
if input_paths
if prepared_inputs
else "paddleocr_mobile"
)
model = (
str(worker_payload.get("model", "PP-OCRv5_mobile"))
if input_paths
if prepared_inputs
else "PP-OCRv5_mobile"
)
return OcrRecognizeBatchRead(
@@ -108,8 +181,7 @@ class OcrService:
documents=documents,
)
finally:
for path in input_paths:
path.unlink(missing_ok=True)
self._cleanup_temp_paths(cleanup_paths)
def _resolve_python_bin(self) -> str:
candidates = []
@@ -182,17 +254,182 @@ class OcrService:
return json.loads(normalized[len(WORKER_JSON_PREFIX) :])
return None
@staticmethod
def _build_document(
payload: dict,
meta_by_path: dict[str, tuple[str, str]],
) -> OcrRecognizeDocumentRead:
input_path = str(payload.get("input_path") or "")
filename, media_type = meta_by_path.get(
input_path,
(Path(input_path).name or "upload.bin", "application/octet-stream"),
def _prepare_pdf_inputs(
self,
*,
pdf_path: Path,
filename: str,
media_type: str,
cleanup_paths: list[Path],
) -> list[PreparedOcrInput]:
output_dir = pdf_path.with_suffix("")
output_dir.mkdir(parents=True, exist_ok=True)
cleanup_paths.append(output_dir)
image_paths = self._convert_pdf_to_images(pdf_path=pdf_path, output_dir=output_dir)
if not image_paths:
raise RuntimeError("PDF 转图片后未生成可识别页面。")
preview_data_url = self._build_preview_data_url(image_paths[0], media_type="image/png")
source_key = uuid4().hex
descriptors: list[PreparedOcrInput] = []
for page_index, image_path in enumerate(image_paths):
descriptors.append(
PreparedOcrInput(
input_path=image_path,
source_key=source_key,
filename=filename,
media_type=media_type,
page_index=page_index,
preview_kind="image" if page_index == 0 else "",
preview_data_url=preview_data_url if page_index == 0 else "",
)
lines = [
)
return descriptors
def _convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]:
prefix = output_dir / "page"
completed = subprocess.run(
[
"pdftoppm",
"-png",
"-r",
"160",
str(pdf_path),
str(prefix),
],
capture_output=True,
text=True,
timeout=self.settings.ocr_timeout_seconds,
check=False,
)
if completed.returncode != 0:
detail = (completed.stderr or completed.stdout or "").strip()
raise RuntimeError(f"PDF 转图片失败:{detail or 'pdftoppm 返回非 0 状态码。'}")
return sorted(output_dir.glob("page-*.png"), key=self._extract_pdf_page_sort_key)
@staticmethod
def _extract_pdf_page_sort_key(path: Path) -> tuple[int, str]:
suffix = path.stem.rsplit("-", 1)[-1]
try:
return int(suffix), path.name
except ValueError:
return 0, path.name
@staticmethod
def _build_preview_data_url(path: Path, *, media_type: str) -> str:
encoded = base64.b64encode(path.read_bytes()).decode("ascii")
return f"data:{media_type};base64,{encoded}"
def _build_documents(
self,
*,
worker_documents: list[dict],
prepared_inputs: list[PreparedOcrInput],
) -> list[OcrRecognizeDocumentRead]:
descriptor_by_path = {str(item.input_path): item for item in prepared_inputs}
source_order: list[str] = []
seen_sources: set[str] = set()
for item in prepared_inputs:
if item.source_key in seen_sources:
continue
seen_sources.add(item.source_key)
source_order.append(item.source_key)
aggregated_by_source: dict[str, AggregatedOcrDocument] = {}
for payload in worker_documents:
if not isinstance(payload, dict):
continue
input_path = str(payload.get("input_path") or "")
descriptor = descriptor_by_path.get(input_path)
if descriptor is None:
continue
aggregated = aggregated_by_source.get(descriptor.source_key)
if aggregated is None:
aggregated = AggregatedOcrDocument(
filename=descriptor.filename,
media_type=descriptor.media_type,
source_key=descriptor.source_key,
engine=str(payload.get("engine", "paddleocr_mobile")),
model=str(payload.get("model", "PP-OCRv5_mobile")),
)
aggregated_by_source[descriptor.source_key] = aggregated
aggregated.page_count = max(
aggregated.page_count,
(descriptor.page_index + 1)
if descriptor.page_index is not None
else int(payload.get("page_count", 1) or 1),
)
if descriptor.preview_kind and not aggregated.preview_kind:
aggregated.preview_kind = descriptor.preview_kind
if descriptor.preview_data_url and not aggregated.preview_data_url:
aggregated.preview_data_url = descriptor.preview_data_url
page_summary = str(payload.get("summary", "") or "").strip()
if page_summary:
aggregated.summary_fragments.append(page_summary)
page_text = str(payload.get("text", "") or "").strip()
if page_text:
aggregated.text_fragments.append(page_text)
lines = self._build_lines(
payload.get("lines", []),
page_index_override=descriptor.page_index,
)
aggregated.lines.extend(lines)
aggregated.score_values.extend(line.score for line in lines if line.score > 0)
if not lines:
avg_score = float(payload.get("avg_score", 0.0) or 0.0)
if avg_score > 0:
aggregated.score_values.append(avg_score)
for warning in payload.get("warnings", []):
normalized_warning = str(warning or "").strip()
if normalized_warning and normalized_warning not in aggregated.warnings:
aggregated.warnings.append(normalized_warning)
documents: list[OcrRecognizeDocumentRead] = []
for source_key in source_order:
descriptors = [item for item in prepared_inputs if item.source_key == source_key]
if not descriptors:
continue
aggregated = aggregated_by_source.get(source_key)
if aggregated is None:
first_descriptor = descriptors[0]
documents.append(
OcrRecognizeDocumentRead(
filename=first_descriptor.filename,
media_type=first_descriptor.media_type,
page_count=max(1, len(descriptors)),
preview_kind=first_descriptor.preview_kind,
preview_data_url=first_descriptor.preview_data_url,
warnings=["OCR worker 未返回该文件的识别结果。"],
)
)
continue
documents.append(self._finalize_document(aggregated))
return documents
@staticmethod
def _build_lines(
items: list[dict],
*,
page_index_override: int | None = None,
) -> list[OcrRecognizeLineRead]:
lines: list[OcrRecognizeLineRead] = []
for item in items:
if not isinstance(item, dict):
continue
page_index = page_index_override
if page_index is None and item.get("page_index") is not None:
page_index = int(item["page_index"])
lines.append(
OcrRecognizeLineRead(
text=str(item.get("text", "")),
score=float(item.get("score", 0.0) or 0.0),
@@ -201,21 +438,74 @@ class OcrService:
for point in item.get("box", [])
if isinstance(point, list) and len(point) == 2
],
page_index=int(item["page_index"]) if item.get("page_index") is not None else None,
page_index=page_index,
)
for item in payload.get("lines", [])
if isinstance(item, dict)
]
)
return lines
@staticmethod
def _truncate_summary(parts: list[str]) -> str:
summary = "".join([part for part in parts if part][:3])
if len(summary) > 180:
return f"{summary[:177]}..."
return summary
def _finalize_document(self, aggregated: AggregatedOcrDocument) -> OcrRecognizeDocumentRead:
full_text = "\n".join(fragment for fragment in aggregated.text_fragments if fragment).strip()
summary = self._truncate_summary(aggregated.summary_fragments or aggregated.text_fragments)
insight = self.document_intelligence_service.build_document_insight(
filename=aggregated.filename,
summary=summary,
text=full_text,
preview_data_url=aggregated.preview_data_url,
)
warnings = list(aggregated.warnings)
for warning in insight.warnings:
normalized_warning = str(warning or "").strip()
if normalized_warning and normalized_warning not in warnings:
warnings.append(normalized_warning)
return OcrRecognizeDocumentRead(
filename=filename,
media_type=media_type,
engine=str(payload.get("engine", "paddleocr_mobile")),
model=str(payload.get("model", "PP-OCRv5_mobile")),
text=str(payload.get("text", "")),
summary=str(payload.get("summary", "")),
avg_score=float(payload.get("avg_score", 0.0) or 0.0),
line_count=int(payload.get("line_count", len(lines)) or 0),
page_count=int(payload.get("page_count", 1) or 1),
warnings=[str(item) for item in payload.get("warnings", [])],
lines=lines,
filename=aggregated.filename,
media_type=aggregated.media_type,
engine=aggregated.engine,
model=aggregated.model,
text=full_text,
summary=summary,
avg_score=(
sum(aggregated.score_values) / len(aggregated.score_values)
if aggregated.score_values
else 0.0
),
line_count=len(aggregated.lines),
page_count=max(1, aggregated.page_count),
document_type=insight.document_type,
document_type_label=insight.document_type_label,
scene_code=insight.scene_code,
scene_label=insight.scene_label,
classification_source=insight.classification_source,
classification_confidence=insight.classification_confidence,
classification_evidence=list(insight.evidence),
document_fields=[
OcrRecognizeFieldRead(
key=field.key,
label=field.label,
value=field.value,
)
for field in insight.fields
],
preview_kind=aggregated.preview_kind,
preview_data_url=aggregated.preview_data_url,
warnings=warnings,
lines=sorted(
aggregated.lines,
key=lambda item: item.page_index if item.page_index is not None else -1,
),
)
@staticmethod
def _cleanup_temp_paths(paths: list[Path]) -> None:
for path in reversed(paths):
if path.is_dir():
shutil.rmtree(path, ignore_errors=True)
continue
path.unlink(missing_ok=True)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
import json
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.services.document_intelligence import DocumentIntelligenceService, build_document_insight
from app.services.runtime_chat import RuntimeChatService
def test_build_document_insight_prefers_transport_for_didi_text_with_hotel_noise() -> None:
insight = build_document_insight(
filename="didi-trip.png",
summary="滴滴出行行程单",
text="滴滴出行电子发票 订单号 12345678 上车点 深圳湾 下车点 后海 全季酒店 里程 12.4 公里 金额 48 元",
)
assert insight.document_type == "taxi_receipt"
assert insight.document_type_label == "出租车/网约车票据"
assert insight.scene_code == "transport"
assert any(field.label == "金额" and field.value == "48元" for field in insight.fields)
def test_document_intelligence_service_uses_vlm_result_when_preview_available(monkeypatch) -> None:
calls: list[tuple[str, ...]] = []
def fake_complete(self, messages, *, slot_priority=("main", "backup"), max_tokens=500, temperature=0.2):
calls.append(slot_priority)
if slot_priority == ("vlm",):
assert isinstance(messages[1]["content"], list)
return json.dumps(
{
"document_type": "taxi_receipt",
"scene_code": "transport",
"scene_label": "交通票据",
"expense_type": "transport",
"confidence": 0.91,
"evidence": ["图片主体为滴滴行程单OCR 中出现订单号、上车、下车等字段"],
},
ensure_ascii=False,
)
return None
monkeypatch.setattr(RuntimeChatService, "complete", fake_complete)
engine = create_engine(
"sqlite+pysqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
session = sessionmaker(bind=engine, autoflush=False, autocommit=False)()
try:
insight = DocumentIntelligenceService(session).build_document_insight(
filename="mixed-noise.png",
summary="OCR 混入酒店名称",
text="全季酒店 滴滴出行 订单号 12345678 上车 下车 金额 52 元",
preview_data_url="data:image/png;base64,ZmFrZQ==",
)
finally:
session.close()
assert insight.document_type == "taxi_receipt"
assert insight.classification_source == "llm_vision"
assert calls[0] == ("vlm",)

View File

@@ -10,7 +10,7 @@ from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeFieldRead, OcrRecognizeLineRead
from app.services.ocr import OcrService
@@ -50,14 +50,23 @@ def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None:
OcrRecognizeDocumentRead(
filename="invoice.png",
media_type="image/png",
text="发票金额 100 元",
summary="发票金额 100 元",
text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
summary="增值税电子发票金额 100 元",
avg_score=0.98,
line_count=1,
page_count=1,
document_type="vat_invoice",
document_type_label="增值税发票",
scene_code="other",
scene_label="通用发票",
document_fields=[
OcrRecognizeFieldRead(key="amount", label="金额", value="100元"),
OcrRecognizeFieldRead(key="date", label="日期", value="2026-05-13"),
OcrRecognizeFieldRead(key="invoice_number", label="票据号码", value="12345678"),
],
lines=[
OcrRecognizeLineRead(
text="发票金额 100 元",
text="增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
score=0.98,
box=[[1, 2], [10, 2], [10, 8], [1, 8]],
page_index=0,
@@ -81,4 +90,7 @@ def test_ocr_recognize_endpoint_returns_structured_payload(monkeypatch) -> None:
assert payload["engine"] == "paddleocr_mobile"
assert payload["success_count"] == 1
assert payload["documents"][0]["filename"] == "invoice.png"
assert payload["documents"][0]["summary"] == "发票金额 100 元"
assert payload["documents"][0]["summary"] == "增值税电子发票金额 100 元"
assert payload["documents"][0]["document_type"] == "vat_invoice"
assert payload["documents"][0]["document_type_label"] == "增值税发票"
assert payload["documents"][0]["document_fields"][0]["label"] == "金额"

View File

@@ -26,15 +26,15 @@ for index, arg in enumerate(sys.argv):
"input_path": input_path,
"engine": "paddleocr_mobile",
"model": "PP-OCRv5_mobile",
"text": "发票金额 100 元",
"summary": "发票金额 100 元",
"text": "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
"summary": "增值税电子发票金额 100 元",
"avg_score": 0.98,
"line_count": 1,
"page_count": 1,
"warnings": [],
"lines": [
{
"text": "发票金额 100 元",
"text": "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13",
"score": 0.98,
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
"page_index": 0,
@@ -74,10 +74,106 @@ print("__OCR_JSON__=" + json.dumps(payload, ensure_ascii=False))
assert len(result.documents) == 2
recognized = next(item for item in result.documents if item.filename == "invoice.png")
assert recognized.summary == "发票金额 100 元"
assert recognized.summary == "增值税电子发票金额 100 元"
assert recognized.line_count == 1
assert recognized.lines[0].text == "发票金额 100 元"
assert recognized.document_type == "vat_invoice"
assert recognized.document_type_label == "增值税发票"
assert any(field.label == "金额" and field.value == "100元" for field in recognized.document_fields)
assert any(field.label == "票据号码" and field.value == "12345678" for field in recognized.document_fields)
assert any(field.label == "日期" and field.value == "2026-05-13" for field in recognized.document_fields)
assert recognized.lines[0].text == "增值税电子发票 发票号码12345678 金额 100 元 2026-05-13"
skipped = next(item for item in result.documents if item.filename == "notes.txt")
assert skipped.line_count == 0
assert skipped.warnings == ["当前仅支持图片和 PDF 文件进行 OCR。"]
def test_ocr_service_converts_pdf_to_images_and_returns_image_preview(
monkeypatch,
tmp_path: Path,
) -> None:
def fake_convert_pdf_to_images(self, *, pdf_path: Path, output_dir: Path) -> list[Path]:
first = output_dir / "page-1.png"
second = output_dir / "page-2.png"
first.write_bytes(b"fake-page-1")
second.write_bytes(b"fake-page-2")
return [first, second]
def fake_invoke_worker(
self,
*,
python_bin: str,
worker_path: str,
input_paths: list[Path],
) -> dict:
assert [path.name for path in input_paths] == ["page-1.png", "page-2.png"]
return {
"engine": "paddleocr_mobile",
"model": "PP-OCRv5_mobile",
"documents": [
{
"input_path": str(input_paths[0]),
"engine": "paddleocr_mobile",
"model": "PP-OCRv5_mobile",
"text": "高铁票 深圳北-广州南 车次 G1234 2026-05-13 金额 188 元",
"summary": "高铁票第一页",
"avg_score": 0.97,
"line_count": 1,
"page_count": 1,
"warnings": [],
"lines": [
{
"text": "高铁票 深圳北-广州南 车次 G1234 2026-05-13 金额 188 元",
"score": 0.97,
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
}
],
},
{
"input_path": str(input_paths[1]),
"engine": "paddleocr_mobile",
"model": "PP-OCRv5_mobile",
"text": "乘车人 张三",
"summary": "高铁票第二页",
"avg_score": 0.94,
"line_count": 1,
"page_count": 1,
"warnings": [],
"lines": [
{
"text": "乘车人 张三",
"score": 0.94,
"box": [[1, 2], [10, 2], [10, 8], [1, 8]],
}
],
},
],
}
monkeypatch.setenv("STORAGE_ROOT_DIR", str(tmp_path / "storage"))
monkeypatch.setattr(OcrService, "_resolve_python_bin", lambda self: "python")
monkeypatch.setattr(OcrService, "_resolve_worker_path", lambda self: "worker.py")
monkeypatch.setattr(OcrService, "_convert_pdf_to_images", fake_convert_pdf_to_images)
monkeypatch.setattr(OcrService, "_invoke_worker", fake_invoke_worker)
get_settings.cache_clear()
try:
result = OcrService().recognize_files(
[
("train-ticket.pdf", b"%PDF-1.4 fake", "application/pdf"),
]
)
finally:
get_settings.cache_clear()
assert result.success_count == 1
assert len(result.documents) == 1
recognized = result.documents[0]
assert recognized.filename == "train-ticket.pdf"
assert recognized.page_count == 2
assert recognized.preview_kind == "image"
assert recognized.preview_data_url.startswith("data:image/png;base64,")
assert recognized.document_type == "train_ticket"
assert any(field.label == "金额" and field.value == "188元" for field in recognized.document_fields)
assert any(field.label == "车次/航班" and field.value == "G1234" for field in recognized.document_fields)
assert recognized.lines[0].page_index == 0
assert recognized.lines[1].page_index == 1