feat(server): add OCR invoice processing functionality
New endpoints: - server/src/app/api/v1/endpoints/ocr.py: OCR API endpoints for invoice scanning New schemas: - server/src/app/schemas/ocr.py: OCR request/response data schemas New services: - server/src/app/services/ocr.py: OCR processing business logic - server/src/app/services/expense_claims.py: expense claims management service Scripts: - server/scripts/bootstrap_paddleocr_mobile.sh: PaddleOCR mobile setup script - server/scripts/paddle_ocr_worker.py: PaddleOCR worker process
This commit is contained in:
20
server/scripts/bootstrap_paddleocr_mobile.sh
Normal file
20
server/scripts/bootstrap_paddleocr_mobile.sh
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
OCR_VENV_DIR="${ROOT_DIR}/.venv-ocr312"
|
||||
PYTHON_BIN="${PYTHON_BIN:-python3.12}"
|
||||
|
||||
if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then
|
||||
echo "python3.12 不存在,请先安装 Python 3.12。" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
apt-get update
|
||||
apt-get install -y libgl1 libglib2.0-0
|
||||
|
||||
"${PYTHON_BIN}" -m venv "${OCR_VENV_DIR}"
|
||||
"${OCR_VENV_DIR}/bin/pip" install --upgrade pip
|
||||
"${OCR_VENV_DIR}/bin/pip" install "paddlepaddle==3.2.0" "paddleocr==3.5.0"
|
||||
|
||||
echo "PaddleOCR mobile runtime 已安装到 ${OCR_VENV_DIR}"
|
||||
126
server/scripts/paddle_ocr_worker.py
Normal file
126
server/scripts/paddle_ocr_worker.py
Normal file
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from statistics import fmean
|
||||
from typing import Any
|
||||
|
||||
os.environ.setdefault("PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK", "True")
|
||||
|
||||
from paddleocr import PaddleOCR # noqa: E402
|
||||
|
||||
WORKER_JSON_PREFIX = "__OCR_JSON__="
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Run PaddleOCR mobile worker.")
|
||||
parser.add_argument("--input", action="append", dest="inputs", required=True)
|
||||
parser.add_argument("--lang", default="ch")
|
||||
parser.add_argument("--text-detection-model", default="PP-OCRv5_mobile_det")
|
||||
parser.add_argument("--text-recognition-model", default="PP-OCRv5_mobile_rec")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def coerce_box(box: Any) -> list[list[int]]:
|
||||
if not isinstance(box, list):
|
||||
return []
|
||||
points: list[list[int]] = []
|
||||
for point in box:
|
||||
if not isinstance(point, list) or len(point) != 2:
|
||||
continue
|
||||
points.append([int(point[0]), int(point[1])])
|
||||
return points
|
||||
|
||||
|
||||
def build_document(input_path: str, results: list[Any]) -> dict[str, Any]:
|
||||
lines: list[dict[str, Any]] = []
|
||||
all_texts: list[str] = []
|
||||
all_scores: list[float] = []
|
||||
|
||||
for fallback_page_index, result in enumerate(results):
|
||||
payload = result.json
|
||||
if isinstance(payload, str):
|
||||
payload = json.loads(payload)
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
res = payload.get("res", payload)
|
||||
if not isinstance(res, dict):
|
||||
continue
|
||||
|
||||
page_index = res.get("page_index")
|
||||
if page_index is None:
|
||||
page_index = fallback_page_index if len(results) > 1 else None
|
||||
|
||||
texts = res.get("rec_texts", [])
|
||||
scores = res.get("rec_scores", [])
|
||||
boxes = res.get("rec_polys") or res.get("dt_polys") or []
|
||||
|
||||
for index, text in enumerate(texts):
|
||||
normalized_text = str(text or "").strip()
|
||||
if not normalized_text:
|
||||
continue
|
||||
score = float(scores[index] if index < len(scores) else 0.0)
|
||||
box = coerce_box(boxes[index] if index < len(boxes) else [])
|
||||
lines.append(
|
||||
{
|
||||
"text": normalized_text,
|
||||
"score": score,
|
||||
"box": box,
|
||||
"page_index": page_index,
|
||||
}
|
||||
)
|
||||
all_texts.append(normalized_text)
|
||||
all_scores.append(score)
|
||||
|
||||
summary = ";".join(all_texts[:3])
|
||||
if len(summary) > 180:
|
||||
summary = f"{summary[:177]}..."
|
||||
|
||||
warnings: list[str] = []
|
||||
if not lines:
|
||||
warnings.append("未识别到可用文本。")
|
||||
|
||||
return {
|
||||
"input_path": input_path,
|
||||
"engine": "paddleocr_mobile",
|
||||
"model": "PP-OCRv5_mobile",
|
||||
"text": "\n".join(all_texts),
|
||||
"summary": summary,
|
||||
"avg_score": float(fmean(all_scores)) if all_scores else 0.0,
|
||||
"line_count": len(lines),
|
||||
"page_count": len(results),
|
||||
"warnings": warnings,
|
||||
"lines": lines,
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
ocr = PaddleOCR(
|
||||
text_detection_model_name=args.text_detection_model,
|
||||
text_recognition_model_name=args.text_recognition_model,
|
||||
use_doc_orientation_classify=False,
|
||||
use_doc_unwarping=False,
|
||||
use_textline_orientation=False,
|
||||
lang=args.lang,
|
||||
)
|
||||
|
||||
documents = []
|
||||
for input_path in args.inputs:
|
||||
results = ocr.predict(input_path)
|
||||
documents.append(build_document(input_path, results))
|
||||
|
||||
payload = {
|
||||
"engine": "paddleocr_mobile",
|
||||
"model": "PP-OCRv5_mobile",
|
||||
"documents": documents,
|
||||
}
|
||||
print(f"{WORKER_JSON_PREFIX}{json.dumps(payload, ensure_ascii=False)}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
56
server/src/app/api/v1/endpoints/ocr.py
Normal file
56
server/src/app/api/v1/endpoints/ocr.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
|
||||
from app.api.deps import CurrentUserContext, get_current_user
|
||||
from app.schemas.common import ErrorResponse
|
||||
from app.schemas.ocr import OcrRecognizeBatchRead
|
||||
from app.services.ocr import OcrService
|
||||
|
||||
router = APIRouter(prefix="/ocr")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/recognize",
|
||||
response_model=OcrRecognizeBatchRead,
|
||||
summary="识别票据或图片 OCR",
|
||||
description="使用 PaddleOCR mobile 模型对上传的图片或 PDF 执行 OCR,并返回结构化文本摘要。",
|
||||
responses={
|
||||
status.HTTP_400_BAD_REQUEST: {
|
||||
"model": ErrorResponse,
|
||||
"description": "未上传文件或文件参数非法。",
|
||||
},
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"model": ErrorResponse,
|
||||
"description": "未提供当前登录用户。",
|
||||
},
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE: {
|
||||
"model": ErrorResponse,
|
||||
"description": "OCR 运行时不可用或执行失败。",
|
||||
},
|
||||
},
|
||||
)
|
||||
async def recognize_ocr_documents(
|
||||
files: Annotated[list[UploadFile], File(description="待识别的票据图片或 PDF。")],
|
||||
_: Annotated[CurrentUserContext, Depends(get_current_user)],
|
||||
) -> OcrRecognizeBatchRead:
|
||||
try:
|
||||
payload = []
|
||||
for upload in files:
|
||||
payload.append(
|
||||
(
|
||||
str(upload.filename or "upload.bin"),
|
||||
await upload.read(),
|
||||
upload.content_type,
|
||||
)
|
||||
)
|
||||
return OcrService().recognize_files(payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
35
server/src/app/schemas/ocr.py
Normal file
35
server/src/app/schemas/ocr.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OcrRecognizeLineRead(BaseModel):
|
||||
text: str = Field(description="识别出的文本行。")
|
||||
score: float = Field(default=0.0, ge=0.0, le=1.0, description="该行识别置信度。")
|
||||
box: list[list[int]] = Field(default_factory=list, description="文本框坐标。")
|
||||
page_index: int | None = Field(default=None, description="页码,从 0 开始。")
|
||||
|
||||
|
||||
class OcrRecognizeDocumentRead(BaseModel):
|
||||
filename: str = Field(description="原始文件名。")
|
||||
media_type: str = Field(description="文件媒体类型。")
|
||||
engine: str = Field(default="paddleocr_mobile", description="使用的 OCR 引擎。")
|
||||
model: str = Field(default="PP-OCRv5_mobile", description="模型族标识。")
|
||||
text: str = Field(default="", description="合并后的完整 OCR 文本。")
|
||||
summary: str = Field(default="", description="供对话和语义层复用的简短摘要。")
|
||||
avg_score: float = Field(default=0.0, ge=0.0, le=1.0, description="平均识别置信度。")
|
||||
line_count: int = Field(default=0, ge=0, description="文本行数。")
|
||||
page_count: int = Field(default=1, ge=0, description="识别页数。")
|
||||
warnings: list[str] = Field(default_factory=list, description="该文件的识别提示或警告。")
|
||||
lines: list[OcrRecognizeLineRead] = Field(default_factory=list, description="逐行识别结果。")
|
||||
|
||||
|
||||
class OcrRecognizeBatchRead(BaseModel):
|
||||
engine: str = Field(default="paddleocr_mobile", description="使用的 OCR 引擎。")
|
||||
model: str = Field(default="PP-OCRv5_mobile", description="模型族标识。")
|
||||
total_file_count: int = Field(default=0, ge=0, description="本次上传的总文件数。")
|
||||
success_count: int = Field(default=0, ge=0, description="成功进入 OCR 的文件数。")
|
||||
documents: list[OcrRecognizeDocumentRead] = Field(
|
||||
default_factory=list,
|
||||
description="逐文件 OCR 结果。",
|
||||
)
|
||||
361
server/src/app/services/expense_claims.py
Normal file
361
server/src/app/services/expense_claims.py
Normal file
@@ -0,0 +1,361 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.employee import Employee
|
||||
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
|
||||
from app.schemas.ontology import OntologyEntity, OntologyParseResult
|
||||
from app.services.audit import AuditLogService
|
||||
from app.services.agent_foundation import AgentFoundationService
|
||||
|
||||
EXPENSE_TYPE_LABELS = {
|
||||
"travel": "差旅",
|
||||
"hotel": "住宿",
|
||||
"transport": "交通",
|
||||
"meal": "餐费",
|
||||
"meeting": "会务",
|
||||
"entertainment": "招待",
|
||||
}
|
||||
|
||||
|
||||
class ExpenseClaimService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
self.audit_service = AuditLogService(db)
|
||||
|
||||
def upsert_draft_from_ontology(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
user_id: str | None,
|
||||
message: str,
|
||||
ontology: OntologyParseResult,
|
||||
context_json: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
self._ensure_ready()
|
||||
|
||||
claim = self._find_target_claim(ontology=ontology, context_json=context_json)
|
||||
before_json = self._serialize_claim(claim) if claim is not None else None
|
||||
|
||||
employee = self._resolve_employee(ontology=ontology, context_json=context_json)
|
||||
amount = self._resolve_amount(ontology.entities)
|
||||
occurred_at = self._resolve_occurred_at(ontology)
|
||||
expense_type = self._resolve_expense_type(ontology.entities)
|
||||
location = self._resolve_location(message=message, context_json=context_json)
|
||||
reason = self._resolve_reason(message=message, context_json=context_json)
|
||||
attachment_count = self._resolve_attachment_count(context_json)
|
||||
|
||||
if claim is None:
|
||||
claim = ExpenseClaim(
|
||||
claim_no=self._generate_claim_no(occurred_at),
|
||||
employee_id=employee.id if employee is not None else None,
|
||||
employee_name=employee.name if employee is not None else self._resolve_employee_name(
|
||||
ontology=ontology,
|
||||
context_json=context_json,
|
||||
user_id=user_id,
|
||||
),
|
||||
department_id=employee.organization_unit_id if employee is not None else None,
|
||||
department_name=self._resolve_department_name(
|
||||
employee=employee,
|
||||
context_json=context_json,
|
||||
),
|
||||
project_code=self._resolve_project_code(ontology.entities),
|
||||
expense_type=expense_type,
|
||||
reason=reason,
|
||||
location=location,
|
||||
amount=amount,
|
||||
currency="CNY",
|
||||
invoice_count=attachment_count,
|
||||
occurred_at=occurred_at,
|
||||
status="draft",
|
||||
approval_stage="待补充",
|
||||
risk_flags_json=list(ontology.risk_flags),
|
||||
)
|
||||
self.db.add(claim)
|
||||
else:
|
||||
claim.employee_id = employee.id if employee is not None else claim.employee_id
|
||||
claim.employee_name = (
|
||||
employee.name
|
||||
if employee is not None
|
||||
else self._resolve_employee_name(
|
||||
ontology=ontology,
|
||||
context_json=context_json,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id
|
||||
claim.department_name = self._resolve_department_name(
|
||||
employee=employee,
|
||||
context_json=context_json,
|
||||
fallback=claim.department_name,
|
||||
)
|
||||
claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code
|
||||
claim.expense_type = expense_type or claim.expense_type
|
||||
claim.reason = reason
|
||||
claim.location = location
|
||||
claim.amount = amount
|
||||
claim.invoice_count = attachment_count
|
||||
claim.occurred_at = occurred_at
|
||||
claim.status = "draft"
|
||||
claim.approval_stage = "待补充"
|
||||
claim.risk_flags_json = list(ontology.risk_flags)
|
||||
|
||||
self.db.flush()
|
||||
self._upsert_primary_item(
|
||||
claim=claim,
|
||||
occurred_at=occurred_at,
|
||||
expense_type=expense_type,
|
||||
amount=amount,
|
||||
reason=reason,
|
||||
location=location,
|
||||
attachment_names=self._resolve_attachment_names(context_json),
|
||||
)
|
||||
self.db.commit()
|
||||
self.db.refresh(claim)
|
||||
|
||||
self.audit_service.log_action(
|
||||
actor=user_id or claim.employee_name or "anonymous",
|
||||
action="expense_claim.draft_upsert",
|
||||
resource_type="expense_claim",
|
||||
resource_id=claim.id,
|
||||
before_json=before_json,
|
||||
after_json=self._serialize_claim(claim),
|
||||
request_id=run_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": (
|
||||
f"已创建报销草稿 {claim.claim_no},当前状态为 draft。"
|
||||
"你可以继续补充费用明细、客户单位和票据附件。"
|
||||
),
|
||||
"draft_only": True,
|
||||
"claim_id": claim.id,
|
||||
"claim_no": claim.claim_no,
|
||||
"status": claim.status,
|
||||
"amount": float(claim.amount),
|
||||
"invoice_count": int(claim.invoice_count or 0),
|
||||
}
|
||||
|
||||
def _find_target_claim(
|
||||
self,
|
||||
*,
|
||||
ontology: OntologyParseResult,
|
||||
context_json: dict[str, Any],
|
||||
) -> ExpenseClaim | None:
|
||||
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
|
||||
if draft_claim_id:
|
||||
return self.db.get(ExpenseClaim, draft_claim_id)
|
||||
|
||||
claim_codes = [
|
||||
item.normalized_value
|
||||
for item in ontology.entities
|
||||
if item.type == "expense_claim" and item.normalized_value
|
||||
]
|
||||
if not claim_codes:
|
||||
return None
|
||||
|
||||
stmt = select(ExpenseClaim).where(ExpenseClaim.claim_no.in_(claim_codes)).limit(1)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def _upsert_primary_item(
|
||||
self,
|
||||
*,
|
||||
claim: ExpenseClaim,
|
||||
occurred_at: datetime,
|
||||
expense_type: str,
|
||||
amount: Decimal,
|
||||
reason: str,
|
||||
location: str,
|
||||
attachment_names: list[str],
|
||||
) -> None:
|
||||
item = claim.items[0] if claim.items else None
|
||||
if item is None:
|
||||
item = ExpenseClaimItem(
|
||||
claim_id=claim.id,
|
||||
item_date=occurred_at.date(),
|
||||
item_type=expense_type,
|
||||
item_reason=reason,
|
||||
item_location=location,
|
||||
item_amount=amount,
|
||||
invoice_id=attachment_names[0] if attachment_names else None,
|
||||
)
|
||||
claim.items.append(item)
|
||||
self.db.add(item)
|
||||
return
|
||||
|
||||
item.item_date = occurred_at.date()
|
||||
item.item_type = expense_type
|
||||
item.item_reason = reason
|
||||
item.item_location = location
|
||||
item.item_amount = amount
|
||||
item.invoice_id = attachment_names[0] if attachment_names else item.invoice_id
|
||||
|
||||
def _generate_claim_no(self, occurred_at: datetime) -> str:
|
||||
month_code = occurred_at.strftime("%Y%m")
|
||||
prefix = f"EXP-{month_code}-"
|
||||
existing = int(
|
||||
self.db.scalar(
|
||||
select(func.count()).select_from(ExpenseClaim).where(ExpenseClaim.claim_no.like(f"{prefix}%"))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
return f"{prefix}{existing + 1:03d}"
|
||||
|
||||
def _resolve_employee(
|
||||
self,
|
||||
*,
|
||||
ontology: OntologyParseResult,
|
||||
context_json: dict[str, Any],
|
||||
) -> Employee | None:
|
||||
employee_name = self._resolve_employee_name(
|
||||
ontology=ontology,
|
||||
context_json=context_json,
|
||||
user_id=None,
|
||||
)
|
||||
if not employee_name:
|
||||
return None
|
||||
|
||||
stmt = select(Employee).where(Employee.name == employee_name).limit(1)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_employee_name(
|
||||
*,
|
||||
ontology: OntologyParseResult,
|
||||
context_json: dict[str, Any],
|
||||
user_id: str | None,
|
||||
) -> str:
|
||||
for item in ontology.entities:
|
||||
if item.type == "employee" and item.value.strip():
|
||||
return item.value.strip()
|
||||
for key in ("name", "user_name", "employee_name"):
|
||||
value = str(context_json.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
return str(user_id or "待补充").strip() or "待补充"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_department_name(
|
||||
*,
|
||||
employee: Employee | None,
|
||||
context_json: dict[str, Any],
|
||||
fallback: str = "待补充",
|
||||
) -> str:
|
||||
if employee is not None and employee.organization_unit is not None:
|
||||
return employee.organization_unit.name
|
||||
|
||||
request_context = context_json.get("request_context")
|
||||
if isinstance(request_context, dict):
|
||||
for key in ("department", "department_name", "deptName"):
|
||||
value = str(request_context.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
|
||||
for key in ("department_name", "department"):
|
||||
value = str(context_json.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
return fallback
|
||||
|
||||
@staticmethod
|
||||
def _resolve_project_code(entities: list[OntologyEntity]) -> str | None:
|
||||
for item in entities:
|
||||
if item.type == "project" and item.normalized_value.strip():
|
||||
return item.normalized_value.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_expense_type(entities: list[OntologyEntity]) -> str:
|
||||
for item in entities:
|
||||
if item.type == "expense_type":
|
||||
normalized = item.normalized_value.strip()
|
||||
if normalized:
|
||||
return normalized
|
||||
return "other"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_reason(*, message: str, context_json: dict[str, Any]) -> str:
|
||||
request_context = context_json.get("request_context")
|
||||
if isinstance(request_context, dict):
|
||||
for key in ("reason", "title"):
|
||||
value = str(request_context.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
return str(message or "").strip()[:500] or "待补充"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str:
|
||||
request_context = context_json.get("request_context")
|
||||
if isinstance(request_context, dict):
|
||||
for key in ("city", "location"):
|
||||
value = str(request_context.get(key) or "").strip()
|
||||
if value:
|
||||
return value
|
||||
compact = str(message or "").replace(" ", "")
|
||||
if "客户现场" in compact:
|
||||
return "客户现场"
|
||||
return "待补充"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_occurred_at(ontology: OntologyParseResult) -> datetime:
|
||||
start_date = ontology.time_range.start_date
|
||||
if start_date:
|
||||
try:
|
||||
parsed = date.fromisoformat(start_date)
|
||||
return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)
|
||||
except ValueError:
|
||||
pass
|
||||
return datetime.now(UTC)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_amount(entities: list[OntologyEntity]) -> Decimal:
|
||||
for item in entities:
|
||||
if item.type != "amount" or item.role == "threshold":
|
||||
continue
|
||||
try:
|
||||
return Decimal(item.normalized_value).quantize(Decimal("0.01"))
|
||||
except (InvalidOperation, ValueError):
|
||||
continue
|
||||
return Decimal("0.00")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_attachment_names(context_json: dict[str, Any]) -> list[str]:
|
||||
names = context_json.get("attachment_names")
|
||||
if not isinstance(names, list):
|
||||
return []
|
||||
return [str(name).strip() for name in names if str(name).strip()]
|
||||
|
||||
def _resolve_attachment_count(self, context_json: dict[str, Any]) -> int:
|
||||
names = self._resolve_attachment_names(context_json)
|
||||
if names:
|
||||
return len(names)
|
||||
try:
|
||||
return max(0, int(context_json.get("attachment_count") or 0))
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _serialize_claim(claim: ExpenseClaim) -> dict[str, Any]:
|
||||
return {
|
||||
"id": claim.id,
|
||||
"claim_no": claim.claim_no,
|
||||
"employee_name": claim.employee_name,
|
||||
"department_name": claim.department_name,
|
||||
"project_code": claim.project_code,
|
||||
"expense_type": claim.expense_type,
|
||||
"reason": claim.reason,
|
||||
"location": claim.location,
|
||||
"amount": float(claim.amount),
|
||||
"invoice_count": int(claim.invoice_count or 0),
|
||||
"status": claim.status,
|
||||
"approval_stage": claim.approval_stage,
|
||||
"risk_flags_json": list(claim.risk_flags_json or []),
|
||||
}
|
||||
|
||||
def _ensure_ready(self) -> None:
|
||||
AgentFoundationService(self.db).ensure_foundation_ready()
|
||||
221
server/src/app/services/ocr.py
Normal file
221
server/src/app/services/ocr.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.config import SERVER_DIR, get_settings
|
||||
from app.schemas.ocr import OcrRecognizeBatchRead, OcrRecognizeDocumentRead, OcrRecognizeLineRead
|
||||
|
||||
WORKER_JSON_PREFIX = "__OCR_JSON__="
|
||||
SUPPORTED_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".pdf"}
|
||||
|
||||
|
||||
class OcrService:
|
||||
def __init__(self) -> None:
|
||||
self.settings = get_settings()
|
||||
|
||||
def recognize_files(
|
||||
self,
|
||||
files: list[tuple[str, bytes, str | None]],
|
||||
) -> OcrRecognizeBatchRead:
|
||||
if not files:
|
||||
raise ValueError("至少需要上传一个文件。")
|
||||
|
||||
temp_root = self.settings.resolved_ocr_temp_dir
|
||||
temp_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
documents: list[OcrRecognizeDocumentRead] = []
|
||||
input_paths: list[Path] = []
|
||||
meta_by_path: dict[str, tuple[str, str]] = {}
|
||||
python_bin = self._resolve_python_bin()
|
||||
worker_path = self._resolve_worker_path()
|
||||
|
||||
try:
|
||||
for filename, content, media_type in files:
|
||||
normalized_name = Path(str(filename or "").strip()).name or "upload.bin"
|
||||
suffix = Path(normalized_name).suffix.lower()
|
||||
resolved_media_type = str(media_type or "application/octet-stream")
|
||||
|
||||
if not content:
|
||||
documents.append(
|
||||
OcrRecognizeDocumentRead(
|
||||
filename=normalized_name,
|
||||
media_type=resolved_media_type,
|
||||
warnings=["文件内容为空,未执行 OCR。"],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if suffix not in SUPPORTED_SUFFIXES:
|
||||
documents.append(
|
||||
OcrRecognizeDocumentRead(
|
||||
filename=normalized_name,
|
||||
media_type=resolved_media_type,
|
||||
warnings=["当前仅支持图片和 PDF 文件进行 OCR。"],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if len(content) > self.settings.ocr_max_file_size_mb * 1024 * 1024:
|
||||
documents.append(
|
||||
OcrRecognizeDocumentRead(
|
||||
filename=normalized_name,
|
||||
media_type=resolved_media_type,
|
||||
warnings=[
|
||||
f"文件超过 {self.settings.ocr_max_file_size_mb} MB,未执行 OCR。"
|
||||
],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
temp_path = temp_root / f"{uuid4().hex}{suffix}"
|
||||
temp_path.write_bytes(content)
|
||||
input_paths.append(temp_path)
|
||||
meta_by_path[str(temp_path)] = (normalized_name, resolved_media_type)
|
||||
|
||||
if input_paths:
|
||||
worker_payload = self._invoke_worker(
|
||||
python_bin=python_bin,
|
||||
worker_path=worker_path,
|
||||
input_paths=input_paths,
|
||||
)
|
||||
for item in worker_payload.get("documents", []):
|
||||
documents.append(self._build_document(item, meta_by_path))
|
||||
|
||||
success_count = sum(
|
||||
1
|
||||
for item in documents
|
||||
if item.line_count > 0 or not item.warnings
|
||||
)
|
||||
engine = (
|
||||
str(worker_payload.get("engine", "paddleocr_mobile"))
|
||||
if input_paths
|
||||
else "paddleocr_mobile"
|
||||
)
|
||||
model = (
|
||||
str(worker_payload.get("model", "PP-OCRv5_mobile"))
|
||||
if input_paths
|
||||
else "PP-OCRv5_mobile"
|
||||
)
|
||||
return OcrRecognizeBatchRead(
|
||||
engine=engine,
|
||||
model=model,
|
||||
total_file_count=len(files),
|
||||
success_count=success_count,
|
||||
documents=documents,
|
||||
)
|
||||
finally:
|
||||
for path in input_paths:
|
||||
path.unlink(missing_ok=True)
|
||||
|
||||
def _resolve_python_bin(self) -> str:
|
||||
candidates = []
|
||||
configured = str(self.settings.ocr_python_bin or "").strip()
|
||||
if configured:
|
||||
candidates.append(configured)
|
||||
candidates.append(str(SERVER_DIR / ".venv-ocr312" / "bin" / "python"))
|
||||
candidates.append("/usr/local/bin/python3.12")
|
||||
resolved = shutil.which("python3.12")
|
||||
if resolved:
|
||||
candidates.append(resolved)
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate and Path(candidate).exists():
|
||||
return candidate
|
||||
|
||||
raise RuntimeError(
|
||||
"未找到可用的 OCR Python 运行时。请先执行 scripts/bootstrap_paddleocr_mobile.sh "
|
||||
"或通过 OCR_PYTHON_BIN 指向已安装 PaddleOCR 的 Python 3.12。"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_worker_path() -> str:
|
||||
worker_path = SERVER_DIR / "scripts" / "paddle_ocr_worker.py"
|
||||
if not worker_path.exists():
|
||||
raise RuntimeError(f"OCR worker 不存在:{worker_path}")
|
||||
return str(worker_path)
|
||||
|
||||
def _invoke_worker(
|
||||
self,
|
||||
*,
|
||||
python_bin: str,
|
||||
worker_path: str,
|
||||
input_paths: list[Path],
|
||||
) -> dict:
|
||||
command = [
|
||||
python_bin,
|
||||
worker_path,
|
||||
"--lang",
|
||||
self.settings.ocr_language,
|
||||
"--text-detection-model",
|
||||
self.settings.ocr_text_detection_model,
|
||||
"--text-recognition-model",
|
||||
self.settings.ocr_text_recognition_model,
|
||||
]
|
||||
for path in input_paths:
|
||||
command.extend(["--input", str(path)])
|
||||
|
||||
completed = subprocess.run(
|
||||
command,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=self.settings.ocr_timeout_seconds,
|
||||
check=False,
|
||||
)
|
||||
if completed.returncode != 0:
|
||||
detail = (completed.stderr or completed.stdout or "").strip()
|
||||
raise RuntimeError(f"OCR 执行失败:{detail or 'worker 返回非 0 状态码。'}")
|
||||
|
||||
payload = self._parse_worker_stdout(completed.stdout)
|
||||
if payload is None:
|
||||
raise RuntimeError("OCR worker 未返回可解析的 JSON 结果。")
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _parse_worker_stdout(stdout: str) -> dict | None:
|
||||
for line in reversed(stdout.splitlines()):
|
||||
normalized = line.strip()
|
||||
if normalized.startswith(WORKER_JSON_PREFIX):
|
||||
return json.loads(normalized[len(WORKER_JSON_PREFIX) :])
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_document(
|
||||
payload: dict,
|
||||
meta_by_path: dict[str, tuple[str, str]],
|
||||
) -> OcrRecognizeDocumentRead:
|
||||
input_path = str(payload.get("input_path") or "")
|
||||
filename, media_type = meta_by_path.get(
|
||||
input_path,
|
||||
(Path(input_path).name or "upload.bin", "application/octet-stream"),
|
||||
)
|
||||
lines = [
|
||||
OcrRecognizeLineRead(
|
||||
text=str(item.get("text", "")),
|
||||
score=float(item.get("score", 0.0) or 0.0),
|
||||
box=[
|
||||
[int(point[0]), int(point[1])]
|
||||
for point in item.get("box", [])
|
||||
if isinstance(point, list) and len(point) == 2
|
||||
],
|
||||
page_index=int(item["page_index"]) if item.get("page_index") is not None else None,
|
||||
)
|
||||
for item in payload.get("lines", [])
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
return OcrRecognizeDocumentRead(
|
||||
filename=filename,
|
||||
media_type=media_type,
|
||||
engine=str(payload.get("engine", "paddleocr_mobile")),
|
||||
model=str(payload.get("model", "PP-OCRv5_mobile")),
|
||||
text=str(payload.get("text", "")),
|
||||
summary=str(payload.get("summary", "")),
|
||||
avg_score=float(payload.get("avg_score", 0.0) or 0.0),
|
||||
line_count=int(payload.get("line_count", len(lines)) or 0),
|
||||
page_count=int(payload.get("page_count", 1) or 1),
|
||||
warnings=[str(item) for item in payload.get("warnings", [])],
|
||||
lines=lines,
|
||||
)
|
||||
Reference in New Issue
Block a user